diff --git a/LICENSE b/LICENSE index 7950dd6ceb6db..66a2e8f132953 100644 --- a/LICENSE +++ b/LICENSE @@ -249,11 +249,11 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (Interpreter classes (all .scala files in repl/src/main/scala except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala), and for SerializableMapWrapper in JavaUtils.scala) - (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.7 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.7 - http://www.scala-lang.org/) - (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.7 - http://www.scala-lang.org/) - (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.7 - http://www.scala-lang.org/) - (BSD-like) Scalap (org.scala-lang:scalap:2.11.7 - http://www.scala-lang.org/) + (BSD-like) Scala Actors library (org.scala-lang:scala-actors:2.11.8 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-compiler:2.11.8 - http://www.scala-lang.org/) + (BSD-like) Scala Compiler (org.scala-lang:scala-reflect:2.11.8 - http://www.scala-lang.org/) + (BSD-like) Scala Library (org.scala-lang:scala-library:2.11.8 - http://www.scala-lang.org/) + (BSD-like) Scalap (org.scala-lang:scalap:2.11.8 - http://www.scala-lang.org/) (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org) (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org) (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org) @@ -297,3 +297,4 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (MIT License) RowsGroup (http://datatables.net/license/mit) (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) (MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE) + (MIT License) machinist (https://github.com/typelevel/machinist) diff --git a/R/README.md b/R/README.md index 4c40c5963db70..1152b1e8e5f9f 100644 --- a/R/README.md +++ b/R/README.md @@ -66,11 +66,7 @@ To run one of them, use `./bin/spark-submit `. For example: ```bash ./bin/spark-submit examples/src/main/r/dataframe.R ``` -You can also run the unit tests for SparkR by running. You need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first: -```bash -R -e 'install.packages("testthat", repos="http://cran.us.r-project.org")' -./R/run-tests.sh -``` +You can run R unit tests by following the instructions under [Running R Tests](http://spark.apache.org/docs/latest/building-spark.html#running-r-tests). ### Running on YARN diff --git a/R/WINDOWS.md b/R/WINDOWS.md index 9ca7e58e20cd2..124bc631be9cd 100644 --- a/R/WINDOWS.md +++ b/R/WINDOWS.md @@ -34,10 +34,9 @@ To run the SparkR unit tests on Windows, the following steps are required —ass 4. Set the environment variable `HADOOP_HOME` to the full path to the newly created `hadoop` directory. -5. Run unit tests for SparkR by running the command below. You need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first: +5. Run unit tests for SparkR by running the command below. You need to install the needed packages following the instructions under [Running R Tests](http://spark.apache.org/docs/latest/building-spark.html#running-r-tests) first: ``` - R -e "install.packages('testthat', repos='http://cran.us.r-project.org')" .\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R ``` diff --git a/R/pkg/.Rbuildignore b/R/pkg/.Rbuildignore index f12f8c275a989..18b2db69db8f1 100644 --- a/R/pkg/.Rbuildignore +++ b/R/pkg/.Rbuildignore @@ -6,3 +6,4 @@ ^README\.Rmd$ ^src-native$ ^html$ +^tests/fulltests/* diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 879c1f80f2c5d..4ac45fc98d5e9 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,8 +1,8 @@ Package: SparkR Type: Package -Version: 2.2.0 +Version: 2.2.1 Title: R Frontend for Apache Spark -Description: The SparkR package provides an R Frontend for Apache Spark. +Description: Provides an R Frontend for Apache Spark. Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), email = "shivaram@cs.berkeley.edu"), person("Xiangrui", "Meng", role = "aut", diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index ca45c6f9b0a96..44e39c4abb472 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -122,6 +122,7 @@ exportMethods("arrange", "group_by", "groupBy", "head", + "hint", "insertInto", "intersect", "isLocal", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 88a138fd8eb1f..3859fa8631b38 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -591,7 +591,7 @@ setMethod("cache", #' #' Persist this SparkDataFrame with the specified storage level. For details of the #' supported storage levels, refer to -#' \url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}. +#' \url{http://spark.apache.org/docs/latest/rdd-programming-guide.html#rdd-persistence}. #' #' @param x the SparkDataFrame to persist. #' @param newLevel storage level chosen for the persistance. See available options in @@ -2642,6 +2642,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' Input SparkDataFrames can have different schemas (names and data types). #' #' Note: This does not remove duplicate rows across the two SparkDataFrames. +#' Also as standard in SQL, this function resolves columns by position (not by name). #' #' @param x A SparkDataFrame #' @param y A SparkDataFrame @@ -3642,3 +3643,33 @@ setMethod("checkpoint", df <- callJMethod(x@sdf, "checkpoint", as.logical(eager)) dataFrame(df) }) + +#' hint +#' +#' Specifies execution plan hint and return a new SparkDataFrame. +#' +#' @param x a SparkDataFrame. +#' @param name a name of the hint. +#' @param ... optional parameters for the hint. +#' @return A SparkDataFrame. +#' @family SparkDataFrame functions +#' @aliases hint,SparkDataFrame,character-method +#' @rdname hint +#' @name hint +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(mtcars) +#' avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg") +#' +#' head(join(df, hint(avg_mpg, "broadcast"), df$cyl == avg_mpg$cyl)) +#' } +#' @note hint since 2.2.0 +setMethod("hint", + signature(x = "SparkDataFrame", name = "character"), + function(x, name, ...) { + parameters <- list(...) + stopifnot(all(sapply(parameters, is.character))) + jdf <- callJMethod(x@sdf, "hint", name, parameters) + dataFrame(jdf) + }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 7ad3993e9ecbc..15ca212acf87f 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -227,7 +227,7 @@ setMethod("cacheRDD", #' #' Persist this RDD with the specified storage level. For details of the #' supported storage levels, refer to -#'\url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}. +#'\url{http://spark.apache.org/docs/latest/rdd-programming-guide.html#rdd-persistence}. #' #' @param x The RDD to persist #' @param newLevel The new storage level to be assigned diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index f5c3a749fe0a1..e3528bc7c3135 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -334,7 +334,7 @@ setMethod("toDF", signature(x = "RDD"), #' #' Loads a JSON file, returning the result as a SparkDataFrame #' By default, (\href{http://jsonlines.org/}{JSON Lines text format or newline-delimited JSON} -#' ) is supported. For JSON (one record per file), set a named property \code{wholeFile} to +#' ) is supported. For JSON (one record per file), set a named property \code{multiLine} to #' \code{TRUE}. #' It goes through the entire dataset once to determine the schema. #' @@ -348,7 +348,7 @@ setMethod("toDF", signature(x = "RDD"), #' sparkR.session() #' path <- "path/to/file.json" #' df <- read.json(path) -#' df <- read.json(path, wholeFile = TRUE) +#' df <- read.json(path, multiLine = TRUE) #' df <- jsonFile(path) #' } #' @name read.json @@ -598,7 +598,7 @@ tableToDF <- function(tableName) { #' df1 <- read.df("path/to/file.json", source = "json") #' schema <- structType(structField("name", "string"), #' structField("info", "map")) -#' df2 <- read.df(mapTypeJsonPath, "json", schema, wholeFile = TRUE) +#' df2 <- read.df(mapTypeJsonPath, "json", schema, multiLine = TRUE) #' df3 <- loadDF("data/test_table", "parquet", mergeSchema = "true") #' } #' @name read.df diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 945676c7f10b3..f8ae5526bc72a 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -572,6 +572,10 @@ setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) #' @export setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) +#' @rdname hint +#' @export +setGeneric("hint", function(x, name, ...) { standardGeneric("hint") }) + #' @rdname insertInto #' @export setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) @@ -1469,7 +1473,7 @@ setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") #' @rdname awaitTermination #' @export -setGeneric("awaitTermination", function(x, timeout) { standardGeneric("awaitTermination") }) +setGeneric("awaitTermination", function(x, timeout = NULL) { standardGeneric("awaitTermination") }) #' @rdname isActive #' @export diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R index 4ca7aa664e023..492dee68e164d 100644 --- a/R/pkg/R/install.R +++ b/R/pkg/R/install.R @@ -267,10 +267,14 @@ hadoopVersionName <- function(hadoopVersion) { # The implementation refers to appdirs package: https://pypi.python.org/pypi/appdirs and # adapt to Spark context sparkCachePath <- function() { - if (.Platform$OS.type == "windows") { + if (is_windows()) { winAppPath <- Sys.getenv("LOCALAPPDATA", unset = NA) if (is.na(winAppPath)) { - stop(paste("%LOCALAPPDATA% not found.", + message("%LOCALAPPDATA% not found. Falling back to %USERPROFILE%.") + winAppPath <- Sys.getenv("USERPROFILE", unset = NA) + } + if (is.na(winAppPath)) { + stop(paste("%LOCALAPPDATA% and %USERPROFILE% not found.", "Please define the environment variable", "or restart and enter an installation path in localDir.")) } else { diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 4db9cc30fb0c1..bdcc0818d139d 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -46,22 +46,25 @@ setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj" #' @note NaiveBayesModel since 2.0.0 setClass("NaiveBayesModel", representation(jobj = "jobj")) -#' linear SVM Model +#' Linear SVM Model #' -#' Fits an linear SVM model against a SparkDataFrame. It is a binary classifier, similar to svm in glmnet package +#' Fits a linear SVM model against a SparkDataFrame, similar to svm in e1071 package. +#' Currently only supports binary classification model with linear kernel. #' Users can print, make predictions on the produced model and save the model to the input path. #' #' @param data SparkDataFrame for training. #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param regParam The regularization parameter. +#' @param regParam The regularization parameter. Only supports L2 regularization currently. #' @param maxIter Maximum iteration number. #' @param tol Convergence tolerance of iterations. #' @param standardization Whether to standardize the training features before fitting the model. The coefficients #' of models will be always returned on the original scale, so it will be transparent for #' users. Note that with/without standardization, the models should be always converged #' to the same solution when no regularization is applied. -#' @param threshold The threshold in binary classification, in range [0, 1]. +#' @param threshold The threshold in binary classification applied to the linear model prediction. +#' This threshold can be any real number, where Inf will make all predictions 0.0 +#' and -Inf will make all predictions 1.0. #' @param weightCol The weight column name. #' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features #' or the number of partitions are large, this param could be adjusted to a larger size. @@ -111,10 +114,10 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu new("LinearSVCModel", jobj = jobj) }) -# Predicted values based on an LinearSVCModel model +# Predicted values based on a LinearSVCModel model #' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns the predicted values based on an LinearSVCModel. +#' @return \code{predict} returns the predicted values based on a LinearSVCModel. #' @rdname spark.svmLinear #' @aliases predict,LinearSVCModel,SparkDataFrame-method #' @export @@ -124,13 +127,12 @@ setMethod("predict", signature(object = "LinearSVCModel"), predict_internal(object, newData) }) -# Get the summary of an LinearSVCModel +# Get the summary of a LinearSVCModel -#' @param object an LinearSVCModel fitted by \code{spark.svmLinear}. +#' @param object a LinearSVCModel fitted by \code{spark.svmLinear}. #' @return \code{summary} returns summary information of the fitted model, which is a list. #' The list includes \code{coefficients} (coefficients of the fitted model), -#' \code{intercept} (intercept of the fitted model), \code{numClasses} (number of classes), -#' \code{numFeatures} (number of features). +#' \code{numClasses} (number of classes), \code{numFeatures} (number of features). #' @rdname spark.svmLinear #' @aliases summary,LinearSVCModel-method #' @export @@ -138,22 +140,14 @@ setMethod("predict", signature(object = "LinearSVCModel"), setMethod("summary", signature(object = "LinearSVCModel"), function(object) { jobj <- object@jobj - features <- callJMethod(jobj, "features") - labels <- callJMethod(jobj, "labels") - coefficients <- callJMethod(jobj, "coefficients") - nCol <- length(coefficients) / length(features) - coefficients <- matrix(unlist(coefficients), ncol = nCol) - intercept <- callJMethod(jobj, "intercept") + features <- callJMethod(jobj, "rFeatures") + coefficients <- callJMethod(jobj, "rCoefficients") + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) numClasses <- callJMethod(jobj, "numClasses") numFeatures <- callJMethod(jobj, "numFeatures") - if (nCol == 1) { - colnames(coefficients) <- c("Estimate") - } else { - colnames(coefficients) <- unlist(labels) - } - rownames(coefficients) <- unlist(features) - list(coefficients = coefficients, intercept = intercept, - numClasses = numClasses, numFeatures = numFeatures) + list(coefficients = coefficients, numClasses = numClasses, numFeatures = numFeatures) }) # Save fitted LinearSVCModel to the input path diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R index e353d2dd07c3d..8390bd5e6de72 100644 --- a/R/pkg/R/streaming.R +++ b/R/pkg/R/streaming.R @@ -169,8 +169,10 @@ setMethod("isActive", #' immediately. #' #' @param x a StreamingQuery. -#' @param timeout time to wait in milliseconds -#' @return TRUE if query has terminated within the timeout period. +#' @param timeout time to wait in milliseconds, if omitted, wait indefinitely until \code{stopQuery} +#' is called or an error has occured. +#' @return TRUE if query has terminated within the timeout period; nothing if timeout is not +#' specified. #' @rdname awaitTermination #' @name awaitTermination #' @aliases awaitTermination,StreamingQuery-method @@ -182,8 +184,12 @@ setMethod("isActive", #' @note experimental setMethod("awaitTermination", signature(x = "StreamingQuery"), - function(x, timeout) { - handledCallJMethod(x@ssq, "awaitTermination", as.integer(timeout)) + function(x, timeout = NULL) { + if (is.null(timeout)) { + invisible(handledCallJMethod(x@ssq, "awaitTermination")) + } else { + handledCallJMethod(x@ssq, "awaitTermination", as.integer(timeout)) + } }) #' stopQuery diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index fbc89e98847bf..7225da9ca896f 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -899,3 +899,15 @@ basenameSansExtFromUrl <- function(url) { isAtomicLengthOne <- function(x) { is.atomic(x) && length(x) == 1 } + +is_windows <- function() { + .Platform$OS.type == "windows" +} + +hadoop_home_set <- function() { + !identical(Sys.getenv("HADOOP_HOME"), "") +} + +windows_with_hadoop <- function() { + !is_windows() || hadoop_home_set() +} diff --git a/R/pkg/inst/tests/testthat/test_basic.R b/R/pkg/inst/tests/testthat/test_basic.R new file mode 100644 index 0000000000000..de47162d5325f --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_basic.R @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("basic tests for CRAN") + +test_that("create DataFrame from list or data.frame", { + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + + i <- 4 + df <- createDataFrame(data.frame(dummy = 1:i)) + expect_equal(count(df), i) + + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(l) + expect_equal(columns(df), c("a", "b")) + + a <- 1:3 + b <- c("a", "b", "c") + ldf <- data.frame(a, b) + df <- createDataFrame(ldf) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(count(df), 3) + ldf2 <- collect(df) + expect_equal(ldf$a, ldf2$a) + + mtcarsdf <- createDataFrame(mtcars) + expect_equivalent(collect(mtcarsdf), mtcars) + + bytes <- as.raw(c(1, 2, 3)) + df <- createDataFrame(list(list(bytes))) + expect_equal(collect(df)[[1]][[1]], bytes) + + sparkR.session.stop() +}) + +test_that("spark.glm and predict", { + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + + training <- suppressWarnings(createDataFrame(iris)) + # gaussian family + model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # Gamma family + x <- runif(100, -1, 1) + y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10) + df <- as.DataFrame(as.data.frame(list(x = x, y = y))) + model <- glm(y ~ x, family = Gamma, df) + out <- capture.output(print(summary(model))) + expect_true(any(grepl("Dispersion parameter for gamma family", out))) + + # tweedie family + model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, + family = "tweedie", var.power = 1.2, link.power = 0.0) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + + # manual calculation of the R predicted values to avoid dependence on statmod + #' library(statmod) + #' rModel <- glm(Sepal.Width ~ Sepal.Length + Species, data = iris, + #' family = tweedie(var.power = 1.2, link.power = 0.0)) + #' print(coef(rModel)) + + rCoef <- c(0.6455409, 0.1169143, -0.3224752, -0.3282174) + rVals <- exp(as.numeric(model.matrix(Sepal.Width ~ Sepal.Length + Species, + data = iris) %*% rCoef)) + expect_true(all(abs(rVals - vals) < 1e-5), rVals - vals) + + sparkR.session.stop() +}) diff --git a/R/pkg/inst/tests/testthat/jarTest.R b/R/pkg/tests/fulltests/jarTest.R similarity index 96% rename from R/pkg/inst/tests/testthat/jarTest.R rename to R/pkg/tests/fulltests/jarTest.R index c9615c8d4faf6..e2241e03b55f8 100644 --- a/R/pkg/inst/tests/testthat/jarTest.R +++ b/R/pkg/tests/fulltests/jarTest.R @@ -16,7 +16,7 @@ # library(SparkR) -sc <- sparkR.session() +sc <- sparkR.session(master = "local[1]") helloTest <- SparkR:::callJStatic("sparkrtest.DummyClass", "helloWorld", diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R b/R/pkg/tests/fulltests/packageInAJarTest.R similarity index 96% rename from R/pkg/inst/tests/testthat/packageInAJarTest.R rename to R/pkg/tests/fulltests/packageInAJarTest.R index 4bc935c79eb0f..ac706261999fb 100644 --- a/R/pkg/inst/tests/testthat/packageInAJarTest.R +++ b/R/pkg/tests/fulltests/packageInAJarTest.R @@ -17,7 +17,7 @@ library(SparkR) library(sparkPackageTest) -sparkR.session() +sparkR.session(master = "local[1]") run1 <- myfunc(5L) diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/tests/fulltests/test_Serde.R similarity index 96% rename from R/pkg/inst/tests/testthat/test_Serde.R rename to R/pkg/tests/fulltests/test_Serde.R index b5f6f1b54fa85..6bbd201bf1d82 100644 --- a/R/pkg/inst/tests/testthat/test_Serde.R +++ b/R/pkg/tests/fulltests/test_Serde.R @@ -17,7 +17,7 @@ context("SerDe functionality") -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("SerDe of primitive types", { x <- callJStatic("SparkRHandler", "echo", 1L) diff --git a/R/pkg/inst/tests/testthat/test_Windows.R b/R/pkg/tests/fulltests/test_Windows.R similarity index 96% rename from R/pkg/inst/tests/testthat/test_Windows.R rename to R/pkg/tests/fulltests/test_Windows.R index 1d777ddb286df..b2ec6c67311db 100644 --- a/R/pkg/inst/tests/testthat/test_Windows.R +++ b/R/pkg/tests/fulltests/test_Windows.R @@ -17,7 +17,7 @@ context("Windows-specific tests") test_that("sparkJars tag in SparkContext", { - if (.Platform$OS.type != "windows") { + if (!is_windows()) { skip("This test is only for Windows, skipped") } diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/tests/fulltests/test_binaryFile.R similarity index 97% rename from R/pkg/inst/tests/testthat/test_binaryFile.R rename to R/pkg/tests/fulltests/test_binaryFile.R index b5c279e3156e5..758b174b8787c 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/tests/fulltests/test_binaryFile.R @@ -18,7 +18,7 @@ context("functions on binary files") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/tests/fulltests/test_binary_function.R similarity index 97% rename from R/pkg/inst/tests/testthat/test_binary_function.R rename to R/pkg/tests/fulltests/test_binary_function.R index 59cb2e6204405..442bed509bb1d 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/tests/fulltests/test_binary_function.R @@ -18,7 +18,7 @@ context("binary functions") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/tests/fulltests/test_broadcast.R similarity index 95% rename from R/pkg/inst/tests/testthat/test_broadcast.R rename to R/pkg/tests/fulltests/test_broadcast.R index 65f204d096f43..5f74d4960451a 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/tests/fulltests/test_broadcast.R @@ -18,7 +18,7 @@ context("broadcast variables") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data diff --git a/R/pkg/inst/tests/testthat/test_client.R b/R/pkg/tests/fulltests/test_client.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_client.R rename to R/pkg/tests/fulltests/test_client.R diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/tests/fulltests/test_context.R similarity index 94% rename from R/pkg/inst/tests/testthat/test_context.R rename to R/pkg/tests/fulltests/test_context.R index c847113491113..73b0f5355518d 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/tests/fulltests/test_context.R @@ -56,7 +56,7 @@ test_that("Check masked functions", { test_that("repeatedly starting and stopping SparkR", { for (i in 1:4) { - sc <- suppressWarnings(sparkR.init()) + sc <- suppressWarnings(sparkR.init(master = sparkRTestMaster)) rdd <- parallelize(sc, 1:20, 2L) expect_equal(countRDD(rdd), 20) suppressWarnings(sparkR.stop()) @@ -65,7 +65,7 @@ test_that("repeatedly starting and stopping SparkR", { test_that("repeatedly starting and stopping SparkSession", { for (i in 1:4) { - sparkR.session(enableHiveSupport = FALSE) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) df <- createDataFrame(data.frame(dummy = 1:i)) expect_equal(count(df), i) sparkR.session.stop() @@ -73,12 +73,12 @@ test_that("repeatedly starting and stopping SparkSession", { }) test_that("rdd GC across sparkR.stop", { - sc <- sparkR.sparkContext() # sc should get id 0 + sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0 rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 sparkR.session.stop() - sc <- sparkR.sparkContext() # sc should get id 0 again + sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0 again # GC rdd1 before creating rdd3 and rdd2 after rm(rdd1) @@ -96,7 +96,7 @@ test_that("rdd GC across sparkR.stop", { }) test_that("job group functions can be called", { - sc <- sparkR.sparkContext() + sc <- sparkR.sparkContext(master = sparkRTestMaster) setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") clearJobGroup() @@ -108,7 +108,7 @@ test_that("job group functions can be called", { }) test_that("utility function can be called", { - sparkR.sparkContext() + sparkR.sparkContext(master = sparkRTestMaster) setLogLevel("ERROR") sparkR.session.stop() }) @@ -161,14 +161,14 @@ test_that("sparkJars sparkPackages as comma-separated strings", { }) test_that("spark.lapply should perform simple transforms", { - sparkR.sparkContext() + sparkR.sparkContext(master = sparkRTestMaster) doubled <- spark.lapply(1:10, function(x) { 2 * x }) expect_equal(doubled, as.list(2 * 1:10)) sparkR.session.stop() }) test_that("add and get file to be downloaded with Spark job on every node", { - sparkR.sparkContext() + sparkR.sparkContext(master = sparkRTestMaster) # Test add file. path <- tempfile(pattern = "hello", fileext = ".txt") filename <- basename(path) diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/tests/fulltests/test_includePackage.R similarity index 95% rename from R/pkg/inst/tests/testthat/test_includePackage.R rename to R/pkg/tests/fulltests/test_includePackage.R index 563ea298c2dd8..f4ea0d1b5cb27 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/tests/fulltests/test_includePackage.R @@ -18,7 +18,7 @@ context("include R packages") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data diff --git a/R/pkg/inst/tests/testthat/test_jvm_api.R b/R/pkg/tests/fulltests/test_jvm_api.R similarity index 93% rename from R/pkg/inst/tests/testthat/test_jvm_api.R rename to R/pkg/tests/fulltests/test_jvm_api.R index 7348c893d0af3..8b3b4f73de170 100644 --- a/R/pkg/inst/tests/testthat/test_jvm_api.R +++ b/R/pkg/tests/fulltests/test_jvm_api.R @@ -17,7 +17,7 @@ context("JVM API") -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("Create and call methods on object", { jarr <- sparkR.newJObject("java.util.ArrayList") diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R similarity index 83% rename from R/pkg/inst/tests/testthat/test_mllib_classification.R rename to R/pkg/tests/fulltests/test_mllib_classification.R index 459254d271a58..726e9d9a20b1c 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib classification algorithms, except for tree-based algorithms") # Tests for MLlib classification algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) absoluteSparkPath <- function(x) { sparkHome <- sparkR.conf("spark.home") @@ -38,9 +38,8 @@ test_that("spark.svmLinear", { expect_true(class(summary$coefficients[, 1]) == "numeric") coefs <- summary$coefficients[, "Estimate"] - expected_coefs <- c(-0.1563083, -0.460648, 0.2276626, 1.055085) + expected_coefs <- c(-0.06004978, -0.1563083, -0.460648, 0.2276626, 1.055085) expect_true(all(abs(coefs - expected_coefs) < 0.1)) - expect_equal(summary$intercept, -0.06004978, tolerance = 1e-2) # Test prediction with string label prediction <- predict(model, training) @@ -50,15 +49,17 @@ test_that("spark.svmLinear", { expect_equal(sort(as.list(take(select(prediction, "prediction"), 10))[[1]]), expected) # Test model save and load - modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - coefs <- summary(model)$coefficients - coefs2 <- summary(model2)$coefficients - expect_equal(coefs, coefs2) - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- summary(model2)$coefficients + expect_equal(coefs, coefs2) + unlink(modelPath) + } # Test prediction with numeric label label <- c(0.0, 0.0, 0.0, 1.0, 1.0) @@ -128,15 +129,17 @@ test_that("spark.logit", { expect_true(all(abs(setosaCoefs - setosaCoefs) < 0.1)) # Test model save and load - modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - coefs <- summary(model)$coefficients - coefs2 <- summary(model2)$coefficients - expect_equal(coefs, coefs2) - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- summary(model2)$coefficients + expect_equal(coefs, coefs2) + unlink(modelPath) + } # R code to reproduce the result. # nolint start @@ -243,19 +246,21 @@ test_that("spark.mlp", { expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", "0.0", "0.0", "0.0")) # Test model save/load - modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - - expect_equal(summary2$numOfInputs, 4) - expect_equal(summary2$numOfOutputs, 3) - expect_equal(summary2$layers, c(4, 5, 4, 3)) - expect_equal(length(summary2$weights), 64) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + + expect_equal(summary2$numOfInputs, 4) + expect_equal(summary2$numOfOutputs, 3) + expect_equal(summary2$layers, c(4, 5, 4, 3)) + expect_equal(length(summary2$weights), 64) + + unlink(modelPath) + } # Test default parameter model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3)) @@ -284,22 +289,11 @@ test_that("spark.mlp", { c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # test initialWeights - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = + model <- spark.mlp(df, label ~ features, layers = c(4, 3), initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = - c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "2.0", "1.0", "0.0")) + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # Test formula works well df <- suppressWarnings(createDataFrame(iris)) @@ -310,8 +304,6 @@ test_that("spark.mlp", { expect_equal(summary$numOfOutputs, 3) expect_equal(summary$layers, c(4, 3)) expect_equal(length(summary$weights), 15) - expect_equal(head(summary$weights, 5), list(-1.1957257, -5.2693685, 7.4489734, -6.3751413, - -10.2376130), tolerance = 1e-6) }) test_that("spark.naiveBayes", { @@ -367,16 +359,18 @@ test_that("spark.naiveBayes", { "Yes", "Yes", "No", "No")) # Test model save/load - modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") - write.ml(m, modelPath) - expect_error(write.ml(m, modelPath)) - write.ml(m, modelPath, overwrite = TRUE) - m2 <- read.ml(modelPath) - s2 <- summary(m2) - expect_equal(s$apriori, s2$apriori) - expect_equal(s$tables, s2$tables) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") + write.ml(m, modelPath) + expect_error(write.ml(m, modelPath)) + write.ml(m, modelPath, overwrite = TRUE) + m2 <- read.ml(modelPath) + s2 <- summary(m2) + expect_equal(s$apriori, s2$apriori) + expect_equal(s$tables, s2$tables) + + unlink(modelPath) + } # Test e1071::naiveBayes if (requireNamespace("e1071", quietly = TRUE)) { diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/tests/fulltests/test_mllib_clustering.R similarity index 79% rename from R/pkg/inst/tests/testthat/test_mllib_clustering.R rename to R/pkg/tests/fulltests/test_mllib_clustering.R index 1661e987b730f..4110e13da4948 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/tests/fulltests/test_mllib_clustering.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib clustering algorithms") # Tests for MLlib clustering algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) absoluteSparkPath <- function(x) { sparkHome <- sparkR.conf("spark.home") @@ -53,18 +53,20 @@ test_that("spark.bisectingKmeans", { c(0, 1, 2, 3)) # Test model save/load - modelPath <- tempfile(pattern = "spark-bisectingkmeans", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) - expect_equal(summary.model$coefficients, summary2$coefficients) - expect_true(!summary.model$is.loaded) - expect_true(summary2$is.loaded) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-bisectingkmeans", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) + expect_equal(summary.model$coefficients, summary2$coefficients) + expect_true(!summary.model$is.loaded) + expect_true(summary2$is.loaded) + + unlink(modelPath) + } }) test_that("spark.gaussianMixture", { @@ -125,18 +127,20 @@ test_that("spark.gaussianMixture", { expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1)) # Test model save/load - modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$lambda, stats2$lambda) - expect_equal(unlist(stats$mu), unlist(stats2$mu)) - expect_equal(unlist(stats$sigma), unlist(stats2$sigma)) - expect_equal(unlist(stats$loglik), unlist(stats2$loglik)) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$lambda, stats2$lambda) + expect_equal(unlist(stats$mu), unlist(stats2$mu)) + expect_equal(unlist(stats$sigma), unlist(stats2$sigma)) + expect_equal(unlist(stats$loglik), unlist(stats2$loglik)) + + unlink(modelPath) + } }) test_that("spark.kmeans", { @@ -171,18 +175,20 @@ test_that("spark.kmeans", { expect_true(class(summary.model$coefficients[1, ]) == "numeric") # Test model save/load - modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) - expect_equal(summary.model$coefficients, summary2$coefficients) - expect_true(!summary.model$is.loaded) - expect_true(summary2$is.loaded) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) + expect_equal(summary.model$coefficients, summary2$coefficients) + expect_true(!summary.model$is.loaded) + expect_true(summary2$is.loaded) + + unlink(modelPath) + } # Test Kmeans on dataset that is sensitive to seed value col1 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) @@ -236,22 +242,24 @@ test_that("spark.lda with libsvm", { expect_true(logPrior <= 0 & !is.na(logPrior)) # Test model save/load - modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - - expect_true(stats2$isDistributed) - expect_equal(logLikelihood, stats2$logLikelihood) - expect_equal(logPerplexity, stats2$logPerplexity) - expect_equal(vocabSize, stats2$vocabSize) - expect_equal(vocabulary, stats2$vocabulary) - expect_equal(trainingLogLikelihood, stats2$trainingLogLikelihood) - expect_equal(logPrior, stats2$logPrior) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + + expect_true(stats2$isDistributed) + expect_equal(logLikelihood, stats2$logLikelihood) + expect_equal(logPerplexity, stats2$logPerplexity) + expect_equal(vocabSize, stats2$vocabSize) + expect_equal(vocabulary, stats2$vocabulary) + expect_equal(trainingLogLikelihood, stats2$trainingLogLikelihood) + expect_equal(logPrior, stats2$logPrior) + + unlink(modelPath) + } }) test_that("spark.lda with text input", { diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/tests/fulltests/test_mllib_fpm.R similarity index 85% rename from R/pkg/inst/tests/testthat/test_mllib_fpm.R rename to R/pkg/tests/fulltests/test_mllib_fpm.R index c38f1133897dd..69dda52f0c279 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_fpm.R +++ b/R/pkg/tests/fulltests/test_mllib_fpm.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib frequent pattern mining") # Tests for MLlib frequent pattern mining algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("spark.fpGrowth", { data <- selectExpr(createDataFrame(data.frame(items = c( @@ -62,15 +62,17 @@ test_that("spark.fpGrowth", { expect_equivalent(expected_predictions, collect(predict(model, new_data))) - modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") - write.ml(model, modelPath, overwrite = TRUE) - loaded_model <- read.ml(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") + write.ml(model, modelPath, overwrite = TRUE) + loaded_model <- read.ml(modelPath) - expect_equivalent( - itemsets, - collect(spark.freqItemsets(loaded_model))) + expect_equivalent( + itemsets, + collect(spark.freqItemsets(loaded_model))) - unlink(modelPath) + unlink(modelPath) + } model_without_numpartitions <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8) expect_equal( diff --git a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R b/R/pkg/tests/fulltests/test_mllib_recommendation.R similarity index 59% rename from R/pkg/inst/tests/testthat/test_mllib_recommendation.R rename to R/pkg/tests/fulltests/test_mllib_recommendation.R index 6b1040db93050..4d919c9d746b0 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R +++ b/R/pkg/tests/fulltests/test_mllib_recommendation.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib recommendation algorithms") # Tests for MLlib recommendation algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("spark.als", { data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), @@ -37,29 +37,31 @@ test_that("spark.als", { tolerance = 1e-4) # Test model save/load - modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats2$rating, "score") - userFactors <- collect(stats$userFactors) - itemFactors <- collect(stats$itemFactors) - userFactors2 <- collect(stats2$userFactors) - itemFactors2 <- collect(stats2$itemFactors) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats2$rating, "score") + userFactors <- collect(stats$userFactors) + itemFactors <- collect(stats$itemFactors) + userFactors2 <- collect(stats2$userFactors) + itemFactors2 <- collect(stats2$itemFactors) - orderUser <- order(userFactors$id) - orderUser2 <- order(userFactors2$id) - expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2]) - expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2]) + orderUser <- order(userFactors$id) + orderUser2 <- order(userFactors2$id) + expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2]) + expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2]) - orderItem <- order(itemFactors$id) - orderItem2 <- order(itemFactors2$id) - expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2]) - expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2]) + orderItem <- order(itemFactors$id) + orderItem2 <- order(itemFactors2$id) + expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2]) + expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2]) - unlink(modelPath) + unlink(modelPath) + } }) sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/tests/fulltests/test_mllib_regression.R similarity index 95% rename from R/pkg/inst/tests/testthat/test_mllib_regression.R rename to R/pkg/tests/fulltests/test_mllib_regression.R index 3e9ad77198073..82472c92b9965 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_regression.R +++ b/R/pkg/tests/fulltests/test_mllib_regression.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib regression algorithms, except for tree-based algorithms") # Tests for MLlib regression algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("formula of spark.glm", { training <- suppressWarnings(createDataFrame(iris)) @@ -389,14 +389,16 @@ test_that("spark.isoreg", { expect_equal(predict_result$prediction, c(7.0, 7.0, 6.0, 5.5, 5.0, 4.0, 1.0)) # Test model save/load - modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - expect_equal(result, summary(model2)) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + expect_equal(result, summary(model2)) + + unlink(modelPath) + } }) test_that("spark.survreg", { @@ -438,17 +440,19 @@ test_that("spark.survreg", { 2.390146, 2.891269, 2.891269), tolerance = 1e-4) # Test model save/load - modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - coefs2 <- as.vector(stats2$coefficients[, 1]) - expect_equal(coefs, coefs2) - expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients)) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + coefs2 <- as.vector(stats2$coefficients[, 1]) + expect_equal(coefs, coefs2) + expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients)) + + unlink(modelPath) + } # Test survival::survreg if (requireNamespace("survival", quietly = TRUE)) { diff --git a/R/pkg/inst/tests/testthat/test_mllib_stat.R b/R/pkg/tests/fulltests/test_mllib_stat.R similarity index 96% rename from R/pkg/inst/tests/testthat/test_mllib_stat.R rename to R/pkg/tests/fulltests/test_mllib_stat.R index beb148e7702fd..1600833a5d03a 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_stat.R +++ b/R/pkg/tests/fulltests/test_mllib_stat.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib statistics algorithms") # Tests for MLlib statistics algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("spark.kstest", { data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25, -1, -0.5)) diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R similarity index 66% rename from R/pkg/inst/tests/testthat/test_mllib_tree.R rename to R/pkg/tests/fulltests/test_mllib_tree.R index e0802a9b02d13..267aa80afdd26 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_tree.R +++ b/R/pkg/tests/fulltests/test_mllib_tree.R @@ -20,7 +20,7 @@ library(testthat) context("MLlib tree-based algorithms") # Tests for MLlib tree-based algorithms in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) absoluteSparkPath <- function(x) { sparkHome <- sparkR.conf("spark.home") @@ -44,21 +44,23 @@ test_that("spark.gbt", { expect_equal(stats$numFeatures, 6) expect_equal(length(stats$treeWeights), 20) - modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$formula, stats2$formula) - expect_equal(stats$numFeatures, stats2$numFeatures) - expect_equal(stats$features, stats2$features) - expect_equal(stats$featureImportances, stats2$featureImportances) - expect_equal(stats$maxDepth, stats2$maxDepth) - expect_equal(stats$numTrees, stats2$numTrees) - expect_equal(stats$treeWeights, stats2$treeWeights) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$maxDepth, stats2$maxDepth) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + } # classification # label must be binary - GBTClassifier currently only supports binary classification. @@ -76,17 +78,19 @@ test_that("spark.gbt", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$depth, stats2$depth) - expect_equal(stats$numNodes, stats2$numNodes) - expect_equal(stats$numClasses, stats2$numClasses) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + } iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) df <- suppressWarnings(createDataFrame(iris2)) @@ -99,10 +103,12 @@ test_that("spark.gbt", { expect_equal(stats$maxDepth, 5) # spark.gbt classification can work on libsvm data - data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), - source = "libsvm") - model <- spark.gbt(data, label ~ features, "classification") - expect_equal(summary(model)$numFeatures, 692) + if (windows_with_hadoop()) { + data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), + source = "libsvm") + model <- spark.gbt(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 692) + } }) test_that("spark.randomForest", { @@ -136,21 +142,23 @@ test_that("spark.randomForest", { expect_equal(stats$numTrees, 20) expect_equal(stats$maxDepth, 5) - modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$formula, stats2$formula) - expect_equal(stats$numFeatures, stats2$numFeatures) - expect_equal(stats$features, stats2$features) - expect_equal(stats$featureImportances, stats2$featureImportances) - expect_equal(stats$numTrees, stats2$numTrees) - expect_equal(stats$maxDepth, stats2$maxDepth) - expect_equal(stats$treeWeights, stats2$treeWeights) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$maxDepth, stats2$maxDepth) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + } # classification data <- suppressWarnings(createDataFrame(iris)) @@ -168,17 +176,19 @@ test_that("spark.randomForest", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$depth, stats2$depth) - expect_equal(stats$numNodes, stats2$numNodes) - expect_equal(stats$numClasses, stats2$numClasses) - - unlink(modelPath) + if (windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + } # Test numeric response variable labelToIndex <- function(species) { @@ -203,10 +213,12 @@ test_that("spark.randomForest", { expect_equal(length(grep("2.0", predictions)), 50) # spark.randomForest classification can work on libsvm data - data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), - source = "libsvm") - model <- spark.randomForest(data, label ~ features, "classification") - expect_equal(summary(model)$numFeatures, 4) + if (windows_with_hadoop()) { + data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + model <- spark.randomForest(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 4) + } }) sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/tests/fulltests/test_parallelize_collect.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_parallelize_collect.R rename to R/pkg/tests/fulltests/test_parallelize_collect.R index 55972e1ba4693..3d122ccaf448f 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ b/R/pkg/tests/fulltests/test_parallelize_collect.R @@ -33,7 +33,7 @@ numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3)) strPairs <- list(list(strList, strList), list(strList, strList)) # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Tests diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/tests/fulltests/test_rdd.R similarity index 99% rename from R/pkg/inst/tests/testthat/test_rdd.R rename to R/pkg/tests/fulltests/test_rdd.R index b72c801dd958d..6ee1fceffd822 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/tests/fulltests/test_rdd.R @@ -18,7 +18,7 @@ context("basic RDD functions") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data @@ -40,8 +40,8 @@ test_that("first on RDD", { }) test_that("count and length on RDD", { - expect_equal(countRDD(rdd), 10) - expect_equal(lengthRDD(rdd), 10) + expect_equal(countRDD(rdd), 10) + expect_equal(lengthRDD(rdd), 10) }) test_that("count by values and keys", { diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/tests/fulltests/test_shuffle.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_shuffle.R rename to R/pkg/tests/fulltests/test_shuffle.R index d38efab0fd1df..98300c67c415f 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/tests/fulltests/test_shuffle.R @@ -18,7 +18,7 @@ context("partitionBy, groupByKey, reduceByKey etc.") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data diff --git a/R/pkg/inst/tests/testthat/test_sparkR.R b/R/pkg/tests/fulltests/test_sparkR.R similarity index 100% rename from R/pkg/inst/tests/testthat/test_sparkR.R rename to R/pkg/tests/fulltests/test_sparkR.R diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R similarity index 92% rename from R/pkg/inst/tests/testthat/test_sparkSQL.R rename to R/pkg/tests/fulltests/test_sparkSQL.R index 6a6c9a809ab13..fc69b4dd24021 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -61,7 +61,11 @@ unsetHiveContext <- function() { # Tests for SparkSQL functions in SparkR filesBefore <- list.files(path = sparkRDir, all.files = TRUE) -sparkSession <- sparkR.session() +sparkSession <- if (windows_with_hadoop()) { + sparkR.session(master = sparkRTestMaster) + } else { + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + } sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockLines <- c("{\"name\":\"Michael\"}", @@ -96,6 +100,10 @@ mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}} mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesMapType, mapTypeJsonPath) +if (is_windows()) { + Sys.setenv(TZ = "GMT") +} + test_that("calling sparkRSQL.init returns existing SQL context", { sqlContext <- suppressWarnings(sparkRSQL.init(sc)) expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) @@ -303,51 +311,53 @@ test_that("createDataFrame uses files for large objects", { }) test_that("read/write csv as DataFrame", { - csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") - mockLinesCsv <- c("year,make,model,comment,blank", - "\"2012\",\"Tesla\",\"S\",\"No comment\",", - "1997,Ford,E350,\"Go get one now they are going fast\",", - "2015,Chevy,Volt", - "NA,Dummy,Placeholder") - writeLines(mockLinesCsv, csvPath) - - # default "header" is false, inferSchema to handle "year" as "int" - df <- read.df(csvPath, "csv", header = "true", inferSchema = "true") - expect_equal(count(df), 4) - expect_equal(columns(df), c("year", "make", "model", "comment", "blank")) - expect_equal(sort(unlist(collect(where(df, df$year == 2015)))), - sort(unlist(list(year = 2015, make = "Chevy", model = "Volt")))) - - # since "year" is "int", let's skip the NA values - withoutna <- na.omit(df, how = "any", cols = "year") - expect_equal(count(withoutna), 3) - - unlink(csvPath) - csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") - mockLinesCsv <- c("year,make,model,comment,blank", - "\"2012\",\"Tesla\",\"S\",\"No comment\",", - "1997,Ford,E350,\"Go get one now they are going fast\",", - "2015,Chevy,Volt", - "Empty,Dummy,Placeholder") - writeLines(mockLinesCsv, csvPath) - - df2 <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "Empty") - expect_equal(count(df2), 4) - withoutna2 <- na.omit(df2, how = "any", cols = "year") - expect_equal(count(withoutna2), 3) - expect_equal(count(where(withoutna2, withoutna2$make == "Dummy")), 0) - - # writing csv file - csvPath2 <- tempfile(pattern = "csvtest2", fileext = ".csv") - write.df(df2, path = csvPath2, "csv", header = "true") - df3 <- read.df(csvPath2, "csv", header = "true") - expect_equal(nrow(df3), nrow(df2)) - expect_equal(colnames(df3), colnames(df2)) - csv <- read.csv(file = list.files(csvPath2, pattern = "^part", full.names = T)[[1]]) - expect_equal(colnames(df3), colnames(csv)) - - unlink(csvPath) - unlink(csvPath2) + if (windows_with_hadoop()) { + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") + mockLinesCsv <- c("year,make,model,comment,blank", + "\"2012\",\"Tesla\",\"S\",\"No comment\",", + "1997,Ford,E350,\"Go get one now they are going fast\",", + "2015,Chevy,Volt", + "NA,Dummy,Placeholder") + writeLines(mockLinesCsv, csvPath) + + # default "header" is false, inferSchema to handle "year" as "int" + df <- read.df(csvPath, "csv", header = "true", inferSchema = "true") + expect_equal(count(df), 4) + expect_equal(columns(df), c("year", "make", "model", "comment", "blank")) + expect_equal(sort(unlist(collect(where(df, df$year == 2015)))), + sort(unlist(list(year = 2015, make = "Chevy", model = "Volt")))) + + # since "year" is "int", let's skip the NA values + withoutna <- na.omit(df, how = "any", cols = "year") + expect_equal(count(withoutna), 3) + + unlink(csvPath) + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") + mockLinesCsv <- c("year,make,model,comment,blank", + "\"2012\",\"Tesla\",\"S\",\"No comment\",", + "1997,Ford,E350,\"Go get one now they are going fast\",", + "2015,Chevy,Volt", + "Empty,Dummy,Placeholder") + writeLines(mockLinesCsv, csvPath) + + df2 <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "Empty") + expect_equal(count(df2), 4) + withoutna2 <- na.omit(df2, how = "any", cols = "year") + expect_equal(count(withoutna2), 3) + expect_equal(count(where(withoutna2, withoutna2$make == "Dummy")), 0) + + # writing csv file + csvPath2 <- tempfile(pattern = "csvtest2", fileext = ".csv") + write.df(df2, path = csvPath2, "csv", header = "true") + df3 <- read.df(csvPath2, "csv", header = "true") + expect_equal(nrow(df3), nrow(df2)) + expect_equal(colnames(df3), colnames(df2)) + csv <- read.csv(file = list.files(csvPath2, pattern = "^part", full.names = T)[[1]]) + expect_equal(colnames(df3), colnames(csv)) + + unlink(csvPath) + unlink(csvPath2) + } }) test_that("Support other types for options", { @@ -570,48 +580,50 @@ test_that("Collect DataFrame with complex types", { }) test_that("read/write json files", { - # Test read.df - df <- read.df(jsonPath, "json") - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 3) - - # Test read.df with a user defined schema - schema <- structType(structField("name", type = "string"), - structField("age", type = "double")) - - df1 <- read.df(jsonPath, "json", schema) - expect_is(df1, "SparkDataFrame") - expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) - - # Test loadDF - df2 <- loadDF(jsonPath, "json", schema) - expect_is(df2, "SparkDataFrame") - expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) - - # Test read.json - df <- read.json(jsonPath) - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 3) - - # Test write.df - jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".json") - write.df(df, jsonPath2, "json", mode = "overwrite") - - # Test write.json - jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json") - write.json(df, jsonPath3) - - # Test read.json()/jsonFile() works with multiple input paths - jsonDF1 <- read.json(c(jsonPath2, jsonPath3)) - expect_is(jsonDF1, "SparkDataFrame") - expect_equal(count(jsonDF1), 6) - # Suppress warnings because jsonFile is deprecated - jsonDF2 <- suppressWarnings(jsonFile(c(jsonPath2, jsonPath3))) - expect_is(jsonDF2, "SparkDataFrame") - expect_equal(count(jsonDF2), 6) - - unlink(jsonPath2) - unlink(jsonPath3) + if (windows_with_hadoop()) { + # Test read.df + df <- read.df(jsonPath, "json") + expect_is(df, "SparkDataFrame") + expect_equal(count(df), 3) + + # Test read.df with a user defined schema + schema <- structType(structField("name", type = "string"), + structField("age", type = "double")) + + df1 <- read.df(jsonPath, "json", schema) + expect_is(df1, "SparkDataFrame") + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + # Test loadDF + df2 <- loadDF(jsonPath, "json", schema) + expect_is(df2, "SparkDataFrame") + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) + + # Test read.json + df <- read.json(jsonPath) + expect_is(df, "SparkDataFrame") + expect_equal(count(df), 3) + + # Test write.df + jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".json") + write.df(df, jsonPath2, "json", mode = "overwrite") + + # Test write.json + jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json") + write.json(df, jsonPath3) + + # Test read.json()/jsonFile() works with multiple input paths + jsonDF1 <- read.json(c(jsonPath2, jsonPath3)) + expect_is(jsonDF1, "SparkDataFrame") + expect_equal(count(jsonDF1), 6) + # Suppress warnings because jsonFile is deprecated + jsonDF2 <- suppressWarnings(jsonFile(c(jsonPath2, jsonPath3))) + expect_is(jsonDF2, "SparkDataFrame") + expect_equal(count(jsonDF2), 6) + + unlink(jsonPath2) + unlink(jsonPath3) + } }) test_that("read/write json files - compression option", { @@ -642,24 +654,27 @@ test_that("jsonRDD() on a RDD with json string", { }) test_that("test tableNames and tables", { + count <- count(listTables()) + df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") - expect_equal(length(tableNames()), 1) - expect_equal(length(tableNames("default")), 1) + expect_equal(length(tableNames()), count + 1) + expect_equal(length(tableNames("default")), count + 1) + tables <- listTables() - expect_equal(count(tables), 1) + expect_equal(count(tables), count + 1) expect_equal(count(tables()), count(tables)) expect_true("tableName" %in% colnames(tables())) expect_true(all(c("tableName", "database", "isTemporary") %in% colnames(tables()))) suppressWarnings(registerTempTable(df, "table2")) tables <- listTables() - expect_equal(count(tables), 2) + expect_equal(count(tables), count + 2) suppressWarnings(dropTempTable("table1")) expect_true(dropTempView("table2")) tables <- listTables() - expect_equal(count(tables), 0) + expect_equal(count(tables), count + 0) }) test_that( @@ -696,33 +711,35 @@ test_that("test cache, uncache and clearCache", { }) test_that("insertInto() on a registered table", { - df <- read.df(jsonPath, "json") - write.df(df, parquetPath, "parquet", "overwrite") - dfParquet <- read.df(parquetPath, "parquet") - - lines <- c("{\"name\":\"Bob\", \"age\":24}", - "{\"name\":\"James\", \"age\":35}") - jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".tmp") - parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - writeLines(lines, jsonPath2) - df2 <- read.df(jsonPath2, "json") - write.df(df2, parquetPath2, "parquet", "overwrite") - dfParquet2 <- read.df(parquetPath2, "parquet") - - createOrReplaceTempView(dfParquet, "table1") - insertInto(dfParquet2, "table1") - expect_equal(count(sql("select * from table1")), 5) - expect_equal(first(sql("select * from table1 order by age"))$name, "Michael") - expect_true(dropTempView("table1")) - - createOrReplaceTempView(dfParquet, "table1") - insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_equal(count(sql("select * from table1")), 2) - expect_equal(first(sql("select * from table1 order by age"))$name, "Bob") - expect_true(dropTempView("table1")) - - unlink(jsonPath2) - unlink(parquetPath2) + if (windows_with_hadoop()) { + df <- read.df(jsonPath, "json") + write.df(df, parquetPath, "parquet", "overwrite") + dfParquet <- read.df(parquetPath, "parquet") + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".tmp") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + writeLines(lines, jsonPath2) + df2 <- read.df(jsonPath2, "json") + write.df(df2, parquetPath2, "parquet", "overwrite") + dfParquet2 <- read.df(parquetPath2, "parquet") + + createOrReplaceTempView(dfParquet, "table1") + insertInto(dfParquet2, "table1") + expect_equal(count(sql("select * from table1")), 5) + expect_equal(first(sql("select * from table1 order by age"))$name, "Michael") + expect_true(dropTempView("table1")) + + createOrReplaceTempView(dfParquet, "table1") + insertInto(dfParquet2, "table1", overwrite = TRUE) + expect_equal(count(sql("select * from table1")), 2) + expect_equal(first(sql("select * from table1 order by age"))$name, "Bob") + expect_true(dropTempView("table1")) + + unlink(jsonPath2) + unlink(parquetPath2) + } }) test_that("tableToDF() returns a new DataFrame", { @@ -902,14 +919,16 @@ test_that("cache(), storageLevel(), persist(), and unpersist() on a DataFrame", }) test_that("setCheckpointDir(), checkpoint() on a DataFrame", { - checkpointDir <- file.path(tempdir(), "cproot") - expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) - - setCheckpointDir(checkpointDir) - df <- read.json(jsonPath) - df <- checkpoint(df) - expect_is(df, "SparkDataFrame") - expect_false(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + if (windows_with_hadoop()) { + checkpointDir <- file.path(tempdir(), "cproot") + expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + + setCheckpointDir(checkpointDir) + df <- read.json(jsonPath) + df <- checkpoint(df) + expect_is(df, "SparkDataFrame") + expect_false(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + } }) test_that("schema(), dtypes(), columns(), names() return the correct values/format", { @@ -1267,45 +1286,47 @@ test_that("column calculation", { }) test_that("test HiveContext", { - setHiveContext(sc) - - schema <- structType(structField("name", "string"), structField("age", "integer"), - structField("height", "float")) - createTable("people", source = "json", schema = schema) - df <- read.df(jsonPathNa, "json", schema) - insertInto(df, "people") - expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) - sql("DROP TABLE people") - - df <- createTable("json", jsonPath, "json") - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 3) - df2 <- sql("select * from json") - expect_is(df2, "SparkDataFrame") - expect_equal(count(df2), 3) - - jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - saveAsTable(df, "json2", "json", "append", path = jsonPath2) - df3 <- sql("select * from json2") - expect_is(df3, "SparkDataFrame") - expect_equal(count(df3), 3) - unlink(jsonPath2) - - hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - saveAsTable(df, "hivetestbl", path = hivetestDataPath) - df4 <- sql("select * from hivetestbl") - expect_is(df4, "SparkDataFrame") - expect_equal(count(df4), 3) - unlink(hivetestDataPath) - - parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath) - df5 <- sql("select * from parquetest") - expect_is(df5, "SparkDataFrame") - expect_equal(count(df5), 3) - unlink(parquetDataPath) - - unsetHiveContext() + if (windows_with_hadoop()) { + setHiveContext(sc) + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + createTable("people", source = "json", schema = schema) + df <- read.df(jsonPathNa, "json", schema) + insertInto(df, "people") + expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) + sql("DROP TABLE people") + + df <- createTable("json", jsonPath, "json") + expect_is(df, "SparkDataFrame") + expect_equal(count(df), 3) + df2 <- sql("select * from json") + expect_is(df2, "SparkDataFrame") + expect_equal(count(df2), 3) + + jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + saveAsTable(df, "json2", "json", "append", path = jsonPath2) + df3 <- sql("select * from json2") + expect_is(df3, "SparkDataFrame") + expect_equal(count(df3), 3) + unlink(jsonPath2) + + hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + saveAsTable(df, "hivetestbl", path = hivetestDataPath) + df4 <- sql("select * from hivetestbl") + expect_is(df4, "SparkDataFrame") + expect_equal(count(df4), 3) + unlink(hivetestDataPath) + + parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath) + df5 <- sql("select * from parquetest") + expect_is(df5, "SparkDataFrame") + expect_equal(count(df5), 3) + unlink(parquetDataPath) + + unsetHiveContext() + } }) test_that("column operators", { @@ -1890,6 +1911,18 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { unlink(jsonPath2) unlink(jsonPath3) + + # Join with broadcast hint + df1 <- sql("SELECT * FROM range(10e10)") + df2 <- sql("SELECT * FROM range(10e10)") + + execution_plan <- capture.output(explain(join(df1, df2, df1$id == df2$id))) + expect_false(any(grepl("BroadcastHashJoin", execution_plan))) + + execution_plan_hint <- capture.output( + explain(join(df1, hint(df2, "broadcast"), df1$id == df2$id)) + ) + expect_true(any(grepl("BroadcastHashJoin", execution_plan_hint))) }) test_that("toJSON() on DataFrame", { @@ -2085,34 +2118,36 @@ test_that("read/write ORC files - compression option", { }) test_that("read/write Parquet files", { - df <- read.df(jsonPath, "json") - # Test write.df and read.df - write.df(df, parquetPath, "parquet", mode = "overwrite") - df2 <- read.df(parquetPath, "parquet") - expect_is(df2, "SparkDataFrame") - expect_equal(count(df2), 3) - - # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile - parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - write.parquet(df, parquetPath2) - parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") - suppressWarnings(saveAsParquetFile(df, parquetPath3)) - parquetDF <- read.parquet(c(parquetPath2, parquetPath3)) - expect_is(parquetDF, "SparkDataFrame") - expect_equal(count(parquetDF), count(df) * 2) - parquetDF2 <- suppressWarnings(parquetFile(parquetPath2, parquetPath3)) - expect_is(parquetDF2, "SparkDataFrame") - expect_equal(count(parquetDF2), count(df) * 2) - - # Test if varargs works with variables - saveMode <- "overwrite" - mergeSchema <- "true" - parquetPath4 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") - write.df(df, parquetPath3, "parquet", mode = saveMode, mergeSchema = mergeSchema) - - unlink(parquetPath2) - unlink(parquetPath3) - unlink(parquetPath4) + if (windows_with_hadoop()) { + df <- read.df(jsonPath, "json") + # Test write.df and read.df + write.df(df, parquetPath, "parquet", mode = "overwrite") + df2 <- read.df(parquetPath, "parquet") + expect_is(df2, "SparkDataFrame") + expect_equal(count(df2), 3) + + # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + write.parquet(df, parquetPath2) + parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + suppressWarnings(saveAsParquetFile(df, parquetPath3)) + parquetDF <- read.parquet(c(parquetPath2, parquetPath3)) + expect_is(parquetDF, "SparkDataFrame") + expect_equal(count(parquetDF), count(df) * 2) + parquetDF2 <- suppressWarnings(parquetFile(parquetPath2, parquetPath3)) + expect_is(parquetDF2, "SparkDataFrame") + expect_equal(count(parquetDF2), count(df) * 2) + + # Test if varargs works with variables + saveMode <- "overwrite" + mergeSchema <- "true" + parquetPath4 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + write.df(df, parquetPath3, "parquet", mode = saveMode, mergeSchema = mergeSchema) + + unlink(parquetPath2) + unlink(parquetPath3) + unlink(parquetPath4) + } }) test_that("read/write Parquet files - compression option/mode", { @@ -2617,7 +2652,6 @@ test_that("dapply() and dapplyCollect() on a DataFrame", { }) test_that("dapplyCollect() on DataFrame with a binary column", { - df <- data.frame(key = 1:3) df$bytes <- lapply(df$key, serialize, connection = NULL) diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R similarity index 93% rename from R/pkg/inst/tests/testthat/test_streaming.R rename to R/pkg/tests/fulltests/test_streaming.R index 03b1bd3dc1f44..d691de7cd725d 100644 --- a/R/pkg/inst/tests/testthat/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -21,10 +21,10 @@ context("Structured Streaming") # Tests for Structured Streaming functions in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) jsonSubDir <- file.path("sparkr-test", "json", "") -if (.Platform$OS.type == "windows") { +if (is_windows()) { # file.path removes the empty separator on Windows, adds it back jsonSubDir <- paste0(jsonSubDir, .Platform$file.sep) } @@ -53,14 +53,17 @@ test_that("read.stream, write.stream, awaitTermination, stopQuery", { q <- write.stream(counts, "memory", queryName = "people", outputMode = "complete") expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 3) writeLines(mockLinesNa, jsonPathNa) awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 6) stopQuery(q) expect_true(awaitTermination(q, 1)) + expect_error(awaitTermination(q), NA) }) test_that("print from explain, lastProgress, status, isActive", { @@ -70,6 +73,7 @@ test_that("print from explain, lastProgress, status, isActive", { q <- write.stream(counts, "memory", queryName = "people2", outputMode = "complete") awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") expect_equal(capture.output(explain(q))[[1]], "== Physical Plan ==") expect_true(any(grepl("\"description\" : \"MemorySink\"", capture.output(lastProgress(q))))) @@ -92,6 +96,7 @@ test_that("Stream other format", { q <- write.stream(counts, "memory", queryName = "people3", outputMode = "complete") expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) expect_equal(queryName(q), "people3") @@ -131,7 +136,7 @@ test_that("Terminated by error", { expect_error(q <- write.stream(counts, "memory", queryName = "people4", outputMode = "complete"), NA) - expect_error(awaitTermination(q, 1), + expect_error(awaitTermination(q, 5 * 1000), paste0(".*(awaitTermination : streaming query error - Invalid value '-1' for option", " 'maxFilesPerTrigger', must be a positive integer).*")) diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/tests/fulltests/test_take.R similarity index 97% rename from R/pkg/inst/tests/testthat/test_take.R rename to R/pkg/tests/fulltests/test_take.R index aaa532856c3d9..8936cc57da227 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/tests/fulltests/test_take.R @@ -30,7 +30,7 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", "raising me. But they're both dead now. I didn't kill them. Honest.") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("take() gives back the original elements in correct count and order", { diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/tests/fulltests/test_textFile.R similarity index 98% rename from R/pkg/inst/tests/testthat/test_textFile.R rename to R/pkg/tests/fulltests/test_textFile.R index 3b466066e9390..be2d2711ff88e 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/tests/fulltests/test_textFile.R @@ -18,7 +18,7 @@ context("the textFile() function") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/tests/fulltests/test_utils.R similarity index 99% rename from R/pkg/inst/tests/testthat/test_utils.R rename to R/pkg/tests/fulltests/test_utils.R index 6d006eccf665e..8cfdc9550d6a8 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/tests/fulltests/test_utils.R @@ -18,7 +18,7 @@ context("functions in utils.R") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("convertJListToRList() gives back (deserializes) the original JLists diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 29812f872c784..0aefd8006caa4 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -21,14 +21,31 @@ library(SparkR) # Turn all warnings into errors options("warn" = 2) +if (.Platform$OS.type == "windows") { + Sys.setenv(TZ = "GMT") +} + # Setup global test environment # Install Spark first to set SPARK_HOME install.spark() sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") -sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") invisible(lapply(sparkRWhitelistSQLDirs, function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) +sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) + +sparkRTestMaster <- "local[1]" +if (identical(Sys.getenv("NOT_CRAN"), "true")) { + sparkRTestMaster <- "" +} test_package("SparkR") + +if (identical(Sys.getenv("NOT_CRAN"), "true")) { + # for testthat 1.0.2 later, change reporter from "summary" to default_reporter() + testthat:::run_tests("SparkR", + file.path(sparkRDir, "pkg", "tests", "fulltests"), + NULL, + "summary") +} diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index a6ff650c33fea..c97ba5f9a1351 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -27,6 +27,17 @@ vignette: > limitations under the License. --> +```{r setup, include=FALSE} +library(knitr) +opts_hooks$set(eval = function(options) { + # override eval to FALSE only on windows + if (.Platform$OS.type == "windows") { + options$eval = FALSE + } + options +}) +``` + ## Overview SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](http://spark.apache.org/mllib/). @@ -46,8 +57,9 @@ We use default settings in which it runs in local mode. It auto downloads Spark ```{r, include=FALSE} install.spark() +sparkR.session(master = "local[1]") ``` -```{r, message=FALSE, results="hide"} +```{r, eval=FALSE} sparkR.session() ``` @@ -65,7 +77,7 @@ We can view the first few rows of the `SparkDataFrame` by `head` or `showDF` fun head(carsDF) ``` -Common data processing operations such as `filter`, `select` are supported on the `SparkDataFrame`. +Common data processing operations such as `filter` and `select` are supported on the `SparkDataFrame`. ```{r} carsSubDF <- select(carsDF, "model", "mpg", "hp") carsSubDF <- filter(carsSubDF, carsSubDF$hp >= 200) @@ -182,7 +194,7 @@ head(df) ``` ### Data Sources -SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL programming guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. +SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL Programming Guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. The general method for creating `SparkDataFrame` from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active Spark Session will be used automatically. SparkR supports reading CSV, JSON and Parquet files natively and through Spark Packages you can find data source connectors for popular file formats like Avro. These packages can be added with `sparkPackages` parameter when initializing SparkSession using `sparkR.session`. @@ -232,7 +244,7 @@ write.df(people, path = "people.parquet", source = "parquet", mode = "overwrite" ``` ### Hive Tables -You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL programming guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). +You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL Programming Guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). ```{r, eval=FALSE} sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") @@ -364,7 +376,7 @@ out <- dapply(carsSubDF, function(x) { x <- cbind(x, x$mpg * 1.61) }, schema) head(collect(out)) ``` -Like `dapply`, apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. +Like `dapply`, `dapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of the function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory. ```{r} out <- dapplyCollect( @@ -390,7 +402,7 @@ result <- gapply( head(arrange(result, "max_mpg", decreasing = TRUE)) ``` -Like gapply, `gapplyCollect` applies a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. +Like `gapply`, `gapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory. ```{r} result <- gapplyCollect( @@ -443,20 +455,20 @@ options(ops) ### SQL Queries -A `SparkDataFrame` can also be registered as a temporary view in Spark SQL and that allows you to run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. +A `SparkDataFrame` can also be registered as a temporary view in Spark SQL so that one can run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. ```{r} people <- read.df(paste0(sparkR.conf("spark.home"), "/examples/src/main/resources/people.json"), "json") ``` -Register this SparkDataFrame as a temporary view. +Register this `SparkDataFrame` as a temporary view. ```{r} createOrReplaceTempView(people, "people") ``` -SQL statements can be run by using the sql method. +SQL statements can be run using the sql method. ```{r} teenagers <- sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") head(teenagers) @@ -505,6 +517,10 @@ SparkR supports the following machine learning models and algorithms. * Alternating Least Squares (ALS) +#### Frequent Pattern Mining + +* FP-growth + #### Statistics * Kolmogorov-Smirnov Test @@ -653,6 +669,7 @@ head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. + ```{r, warning=FALSE} library(survival) ovarianDF <- createDataFrame(ovarian) @@ -707,7 +724,7 @@ summary(tweedieGLM1) ``` We can try other distributions in the tweedie family, for example, a compound Poisson distribution with a log link: ```{r} -tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", +tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", var.power = 1.2, link.power = 0.0) summary(tweedieGLM2) ``` @@ -760,7 +777,7 @@ head(predict(isoregModel, newDF)) `spark.gbt` fits a [gradient-boosted tree](https://en.wikipedia.org/wiki/Gradient_boosting) classification or regression model on a `SparkDataFrame`. Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. -Similar to the random forest example above, we use the `longley` dataset to train a gradient-boosted tree and make predictions: +We use the `longley` dataset to train a gradient-boosted tree and make predictions: ```{r, warning=FALSE} df <- createDataFrame(longley) @@ -800,7 +817,7 @@ head(select(fitted, "Class", "prediction")) `spark.gaussianMixture` fits multivariate [Gaussian Mixture Model](https://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) (GMM) against a `SparkDataFrame`. [Expectation-Maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) (EM) is used to approximate the maximum likelihood estimator (MLE) of the model. -We use a simulated example to demostrate the usage. +We use a simulated example to demonstrate the usage. ```{r} X1 <- data.frame(V1 = rnorm(4), V2 = rnorm(4)) X2 <- data.frame(V1 = rnorm(6, 3), V2 = rnorm(6, 4)) @@ -831,9 +848,9 @@ head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20 * Topics and documents both exist in a feature space, where feature vectors are vectors of word counts (bag of words). -* Rather than estimating a clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. +* Rather than clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. -To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two type options for the column: +To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two options for the column: * character string: This can be a string of the whole document. It will be parsed automatically. Additional stop words can be added in `customizedStopWords`. @@ -881,9 +898,9 @@ perplexity `spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). -There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, `nonnegative`. For a complete list, refer to the help file. +There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, and `nonnegative`. For a complete list, refer to the help file. -```{r} +```{r, eval=FALSE} ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), list(2, 1, 1.0), list(2, 2, 5.0)) df <- createDataFrame(ratings, c("user", "item", "rating")) @@ -891,7 +908,7 @@ model <- spark.als(df, "rating", "user", "item", rank = 10, reg = 0.1, nonnegati ``` Extract latent factors. -```{r} +```{r, eval=FALSE} stats <- summary(model) userFactors <- stats$userFactors itemFactors <- stats$itemFactors @@ -901,11 +918,42 @@ head(itemFactors) Make predictions. -```{r} +```{r, eval=FALSE} predicted <- predict(model, df) head(predicted) ``` +#### FP-growth + +`spark.fpGrowth` executes FP-growth algorithm to mine frequent itemsets on a `SparkDataFrame`. `itemsCol` should be an array of values. + +```{r} +df <- selectExpr(createDataFrame(data.frame(rawItems = c( + "T,R,U", "T,S", "V,R", "R,U,T,V", "R,S", "V,S,U", "U,R", "S,T", "V,R", "V,U,S", + "T,V,U", "R,V", "T,S", "T,S", "S,T", "S,U", "T,R", "V,R", "S,V", "T,S,U" +))), "split(rawItems, ',') AS items") + +fpm <- spark.fpGrowth(df, minSupport = 0.2, minConfidence = 0.5) +``` + +`spark.freqItemsets` method can be used to retrieve a `SparkDataFrame` with the frequent itemsets. + +```{r} +head(spark.freqItemsets(fpm)) +``` + +`spark.associationRules` returns a `SparkDataFrame` with the association rules. + +```{r} +head(spark.associationRules(fpm)) +``` + +We can make predictions based on the `antecedent`. + +```{r} +head(predict(fpm, df)) +``` + #### Kolmogorov-Smirnov Test `spark.kstest` runs a two-sided, one-sample [Kolmogorov-Smirnov (KS) test](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test). @@ -930,7 +978,7 @@ testSummary ### Model Persistence -The following example shows how to save/load an ML model by SparkR. +The following example shows how to save/load an ML model in SparkR. ```{r} t <- as.data.frame(Titanic) training <- createDataFrame(t) @@ -952,6 +1000,72 @@ unlink(modelPath) ``` +## Structured Streaming + +SparkR supports the Structured Streaming API (experimental). + +You can check the Structured Streaming Programming Guide for [an introduction](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#programming-model) to its programming model and basic concepts. + +### Simple Source and Sink + +Spark has a few built-in input sources. As an example, to test with a socket source reading text into words and displaying the computed word counts: + +```{r, eval=FALSE} +# Create DataFrame representing the stream of input lines from connection +lines <- read.stream("socket", host = hostname, port = port) + +# Split the lines into words +words <- selectExpr(lines, "explode(split(value, ' ')) as word") + +# Generate running word count +wordCounts <- count(groupBy(words, "word")) + +# Start running the query that prints the running counts to the console +query <- write.stream(wordCounts, "console", outputMode = "complete") +``` + +### Kafka Source + +It is simple to read data from Kafka. For more information, see [Input Sources](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#input-sources) supported by Structured Streaming. + +```{r, eval=FALSE} +topic <- read.stream("kafka", + kafka.bootstrap.servers = "host1:port1,host2:port2", + subscribe = "topic1") +keyvalue <- selectExpr(topic, "CAST(key AS STRING)", "CAST(value AS STRING)") +``` + +### Operations and Sinks + +Most of the common operations on `SparkDataFrame` are supported for streaming, including selection, projection, and aggregation. Once you have defined the final result, to start the streaming computation, you will call the `write.stream` method setting a sink and `outputMode`. + +A streaming `SparkDataFrame` can be written for debugging to the console, to a temporary in-memory table, or for further processing in a fault-tolerant manner to a File Sink in different formats. + +```{r, eval=FALSE} +noAggDF <- select(where(deviceDataStreamingDf, "signal > 10"), "device") + +# Print new data to console +write.stream(noAggDF, "console") + +# Write new data to Parquet files +write.stream(noAggDF, + "parquet", + path = "path/to/destination/dir", + checkpointLocation = "path/to/checkpoint/dir") + +# Aggregate +aggDF <- count(groupBy(noAggDF, "device")) + +# Print updated aggregations to console +write.stream(aggDF, "console", outputMode = "complete") + +# Have all the aggregates in an in memory table. The query name will be the table name +write.stream(aggDF, "memory", queryName = "aggregates", outputMode = "complete") + +head(sql("select * from aggregates")) +``` + + ## Advanced Topics ### SparkR Object Classes @@ -962,19 +1076,19 @@ There are three main object classes in SparkR you may be working with. + `sdf` stores a reference to the corresponding Spark Dataset in the Spark JVM backend. + `env` saves the meta-information of the object such as `isCached`. -It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. + It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. -* `Column`: an S4 class representing column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding Column object in the Spark JVM backend. +* `Column`: an S4 class representing a column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding `Column` object in the Spark JVM backend. -It can be obtained from a `SparkDataFrame` by `$` operator, `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. + It can be obtained from a `SparkDataFrame` by `$` operator, e.g., `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. -* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a RelationalGroupedDataset object in the backend. +* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a `RelationalGroupedDataset` object in the backend. -This is often an intermediate object with group information and followed up by aggregation operations. + This is often an intermediate object with group information and followed up by aggregation operations. ### Architecture -A complete description of architecture can be seen in reference, in particular the paper *SparkR: Scaling R Programs with Spark*. +A complete description of architecture can be seen in the references, in particular the paper *SparkR: Scaling R Programs with Spark*. Under the hood of SparkR is Spark SQL engine. This avoids the overheads of running interpreted R code, and the optimized SQL execution engine in Spark uses structural information about data and computation flow to perform a bunch of optimizations to speed up the computation. diff --git a/R/run-tests.sh b/R/run-tests.sh index 742a2c5ed76da..29764f48bd156 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)" diff --git a/appveyor.yml b/appveyor.yml index bbb27589cad09..49e09eadee5da 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -26,6 +26,8 @@ branches: only_commits: files: + - appveyor.yml + - dev/appveyor-install-dependencies.ps1 - R/ - sql/core/src/main/scala/org/apache/spark/sql/api/r/ - core/src/main/scala/org/apache/spark/api/r/ @@ -38,16 +40,15 @@ install: # Install maven and dependencies - ps: .\dev\appveyor-install-dependencies.ps1 # Required package for R unit tests - - cmd: R -e "install.packages('testthat', repos='http://cran.us.r-project.org')" - - cmd: R -e "packageVersion('testthat')" - - cmd: R -e "install.packages('e1071', repos='http://cran.us.r-project.org')" - - cmd: R -e "packageVersion('e1071')" - - cmd: R -e "install.packages('survival', repos='http://cran.us.r-project.org')" - - cmd: R -e "packageVersion('survival')" + - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'testthat', 'e1071', 'survival'), repos='http://cran.us.r-project.org')" + - cmd: R -e "packageVersion('knitr'); packageVersion('rmarkdown'); packageVersion('testthat'); packageVersion('e1071'); packageVersion('survival')" build_script: - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package +environment: + NOT_CRAN: true + test_script: - cmd: .\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R @@ -56,4 +57,3 @@ notifications: on_build_success: false on_build_failure: false on_build_status_changed: false - diff --git a/assembly/pom.xml b/assembly/pom.xml index 9d8607d9137c6..f1433918995a6 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.0-csd-1-SNAPSHOT ../pom.xml @@ -34,6 +34,19 @@ assembly none package + + scala-${scala.binary.version}/jars + + + scala-${scala.binary.version} + spark-${project.version}-yarn-shuffle.jar + ${project.parent.basedir}/common/network-yarn/target/${shuffle.jar.dir}/${shuffle.jar.basename} + + + spark + /usr/share/spark + root + 755 @@ -226,5 +239,139 @@ provided + + + deb + + + org.apache.spark + spark-network-shuffle_${scala.binary.version} + ${project.version} + + + + + + maven-antrun-plugin + + + prepare-package + + run + + + + + NOTE: Debian packaging is deprecated and is scheduled to be removed in Spark 1.4. + + + + + + + + org.codehaus.mojo + buildnumber-maven-plugin + 1.2 + + + validate + + create + + + 8 + + + + + + org.vafer + jdeb + 0.11 + + + package + + jdeb + + + ${project.build.directory}/${deb.pkg.name}_${project.version}-${buildNumber}_all.deb + false + gzip + + + ${basedir}/target/${spark.jar.dir} + directory + + perm + ${deb.user} + ${deb.user} + ${deb.install.path}/jars + + + + ${shuffle.jar} + file + + perm + ${deb.user} + ${deb.user} + ${deb.install.path}/yarn + + + + ${basedir}/../conf + directory + + perm + ${deb.user} + ${deb.user} + ${deb.install.path}/conf + ${deb.bin.filemode} + + + + ${basedir}/../bin + directory + + perm + ${deb.user} + ${deb.user} + ${deb.install.path}/bin + ${deb.bin.filemode} + + + + ${basedir}/../sbin + directory + + perm + ${deb.user} + ${deb.user} + ${deb.install.path}/sbin + ${deb.bin.filemode} + + + + ${basedir}/../python + directory + + perm + ${deb.user} + ${deb.user} + ${deb.install.path}/python + ${deb.bin.filemode} + + + + + + + + + + + diff --git a/assembly/src/deb/control/control b/assembly/src/deb/control/control new file mode 100644 index 0000000000000..a6b4471d485f4 --- /dev/null +++ b/assembly/src/deb/control/control @@ -0,0 +1,8 @@ +Package: [[deb.pkg.name]] +Version: [[version]]-[[buildNumber]] +Section: misc +Priority: extra +Architecture: all +Maintainer: Matei Zaharia +Description: [[name]] +Distribution: development diff --git a/assembly/src/main/assembly/assembly.xml b/assembly/src/main/assembly/assembly.xml index 009d4b92f406c..6058c653c99f4 100644 --- a/assembly/src/main/assembly/assembly.xml +++ b/assembly/src/main/assembly/assembly.xml @@ -64,6 +64,15 @@ ${spark.jar.basename} + + + ${project.parent.basedir}/common/network-yarn/target/${shuffle.jar.dir} + + + + ${shuffle.jar.basename} + + diff --git a/bin/spark-class b/bin/spark-class index 77ea40cc37946..65d3b9612909a 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -72,6 +72,8 @@ build_command() { printf "%d\0" $? } +# Turn off posix mode since it does not allow process substitution +set +o posix CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 9faa7d65f83e4..a93fd2f0e54bc 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -29,7 +29,7 @@ if "x%1"=="x" ( ) rem Find Spark jars. -if exist "%SPARK_HOME%\RELEASE" ( +if exist "%SPARK_HOME%\jars" ( set SPARK_JARS_DIR="%SPARK_HOME%\jars" ) else ( set SPARK_JARS_DIR="%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION%\jars" @@ -51,7 +51,7 @@ if not "x%SPARK_PREPEND_CLASSES%"=="x" ( rem Figure out where java is. set RUNNER=java if not "x%JAVA_HOME%"=="x" ( - set RUNNER="%JAVA_HOME%\bin\java" + set RUNNER=%JAVA_HOME%\bin\java ) else ( where /q "%RUNNER%" if ERRORLEVEL 1 ( diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 8657af744c069..a972f1dab4473 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.0-csd-1-SNAPSHOT ../../pom.xml @@ -90,7 +90,8 @@ org.apache.spark spark-tags_${scala.binary.version} - + test + - 1.2.1.spark2 + 1.2.1.spark2-csd-1 1.2.1 10.12.1.1 1.8.2 1.6.0 - 9.2.16.v20160414 + 9.3.11.v20160721 3.1.0 0.8.0 2.4.0 2.0.8 3.1.2 - 1.7.7 hadoop2 0.9.3 @@ -162,7 +157,7 @@ 2.11.8 2.11 1.9.13 - 2.6.5 + 2.7.3 1.1.2.6 1.1.2 1.2.0-incubating @@ -177,7 +172,7 @@ 2.22.2 2.9.3 3.5.2 - 1.3.9 + 3.0.0 0.9.3 4.5.3 1.1 @@ -247,6 +242,18 @@ + + + + ${distRepo.snapshots.id} + ${distRepo.snapshots.url} + + + ${distRepo.releases.id} + ${distRepo.releases.url} + + + @@ -2059,7 +2066,7 @@ ${project.build.directory}/surefire-reports . SparkTestSuite.txt - -ea -Xmx3g -XX:ReservedCodeCacheSize=${CodeCacheSize} + -enableassertions -Xmx3g -XX:ReservedCodeCacheSize=${CodeCacheSize} JavaConversions diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 765c92b8d3b9e..3c828d2ad83fb 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.0-csd-1-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 1ecb3d1958f43..499f27f00e7e0 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -371,7 +371,7 @@ querySpecification (RECORDREADER recordReader=STRING)? fromClause? (WHERE where=booleanExpression)?) - | ((kind=SELECT hint? setQuantifier? namedExpressionSeq fromClause? + | ((kind=SELECT (hints+=hint)* setQuantifier? namedExpressionSeq fromClause? | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?) lateralView* (WHERE where=booleanExpression)? @@ -381,12 +381,12 @@ querySpecification ; hint - : '/*+' hintStatement '*/' + : '/*+' hintStatements+=hintStatement (','? hintStatements+=hintStatement)* '*/' ; hintStatement : hintName=identifier - | hintName=identifier '(' parameters+=identifier (',' parameters+=identifier)* ')' + | hintName=identifier '(' parameters+=primaryExpression (',' parameters+=primaryExpression)* ')' ; fromClause @@ -552,6 +552,7 @@ primaryExpression | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase | CAST '(' expression AS dataType ')' #cast + | STRUCT '(' (argument+=namedExpression (',' argument+=namedExpression)*)? ')' #struct | FIRST '(' expression (IGNORE NULLS)? ')' #first | LAST '(' expression (IGNORE NULLS)? ')' #last | constant #constantDefault @@ -559,7 +560,7 @@ primaryExpression | qualifiedName '.' ASTERISK #star | '(' namedExpression (',' namedExpression)+ ')' #rowConstructor | '(' query ')' #subqueryExpression - | qualifiedName '(' (setQuantifier? namedExpression (',' namedExpression)*)? ')' + | qualifiedName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')' (OVER windowSpec)? #functionCall | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 86de90984ca00..56994fafe064b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -550,7 +550,7 @@ public void copyFrom(UnsafeRow row) { */ public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException { if (baseObject instanceof byte[]) { - int offsetInByteArray = (int) (Platform.BYTE_ARRAY_OFFSET - baseOffset); + int offsetInByteArray = (int) (baseOffset - Platform.BYTE_ARRAY_OFFSET); out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes); } else { int dataRemaining = sizeInBytes; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java index bd5e2d7ecca9b..5f1032d1229da 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java @@ -37,7 +37,9 @@ public class GroupStateTimeout { * `map/flatMapGroupsWithState` by calling `GroupState.setTimeoutDuration()`. See documentation * on `GroupState` for more details. */ - public static GroupStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; } + public static GroupStateTimeout ProcessingTimeTimeout() { + return ProcessingTimeTimeout$.MODULE$; + } /** * Timeout based on event-time. The event-time timestamp for timeout can be set for each @@ -51,4 +53,5 @@ public class GroupStateTimeout { /** No timeout. */ public static GroupStateTimeout NoTimeout() { return NoTimeout$.MODULE$; } + } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java index 3f7cdb293e0fa..2800b3068f87b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java @@ -17,19 +17,15 @@ package org.apache.spark.sql.streaming; -import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.catalyst.streaming.InternalOutputModes; /** - * :: Experimental :: - * * OutputMode is used to what data will be written to a streaming sink when there is * new data available in a streaming DataFrame/Dataset. * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving public class OutputMode { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 86a73a319ec3f..2698faef76902 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -423,6 +423,7 @@ object JavaTypeInference { inputObject, ObjectType(keyType.getRawType), serializerFor(_, keyType), + keyNullable = true, ObjectType(valueType.getRawType), serializerFor(_, valueType), valueNullable = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 82710a2a183ab..c887634adf7bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -511,6 +511,7 @@ object ScalaReflection extends ScalaReflection { inputObject, dataTypeFor(keyType), serializerFor(_, keyType, keyPath, seenTypeSet), + keyNullable = !keyType.typeSymbol.asClass.isPrimitive, dataTypeFor(valueType), serializerFor(_, valueType, valuePath, seenTypeSet), valueNullable = !valueType.typeSymbol.asClass.isPrimitive) @@ -836,8 +837,16 @@ trait ScalaReflection { def getConstructorParameters(tpe: Type): Seq[(String, Type)] = { val formalTypeArgs = tpe.typeSymbol.asClass.typeParams val TypeRef(_, _, actualTypeArgs) = tpe - constructParams(tpe).map { p => - p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val params = constructParams(tpe) + // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int]) + if (actualTypeArgs.nonEmpty) { + params.map { p => + p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + } + } else { + params.map { p => + p.name.toString -> p.typeSignature + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9816b33ae8dff..f707aa820ee57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -136,6 +136,7 @@ class Analyzer( ResolveGroupingAnalytics :: ResolvePivot :: ResolveOrdinalInOrderByAndGroupBy :: + ResolveAggAliasInGroupBy :: ResolveMissingReferences :: ExtractGenerator :: ResolveGenerate :: @@ -150,6 +151,7 @@ class Analyzer( ResolveAggregateFunctions :: TimeWindowing :: ResolveInlineTables(conf) :: + ResolveTimeZone(conf) :: TypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), @@ -161,8 +163,6 @@ class Analyzer( HandleNullInputsForUDF), Batch("FixNullability", Once, FixNullability), - Batch("ResolveTimeZone", Once, - ResolveTimeZone), Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, @@ -173,7 +173,7 @@ class Analyzer( * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -201,7 +201,7 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Lookup WindowSpecDefinitions. This rule works with unresolved children. case WithWindowDefinition(windowDefinitions, child) => child.transform { @@ -243,7 +243,7 @@ class Analyzer( private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) @@ -280,9 +280,15 @@ class Analyzer( * We need to get all of its subsets for a given GROUPBY expression, the subsets are * represented as sequence of expressions. */ - def cubeExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.toList match { + def cubeExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = { + // `cubeExprs0` is recursive and returns a lazy Stream. Here we call `toIndexedSeq` to + // materialize it and avoid serialization problems later on. + cubeExprs0(exprs).toIndexedSeq + } + + def cubeExprs0(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.toList match { case x :: xs => - val initial = cubeExprs(xs) + val initial = cubeExprs0(xs) initial.map(x +: _) ++ initial case Nil => Seq(Seq.empty) @@ -315,7 +321,7 @@ class Analyzer( s"grouping columns (${groupByExprs.mkString(",")})") } case e @ Grouping(col: Expression) => - val idx = groupByExprs.indexOf(col) + val idx = groupByExprs.indexWhere(_.semanticEquals(col)) if (idx >= 0) { Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)), Literal(1)), ByteType), toPrettySQL(e))() @@ -615,7 +621,7 @@ class Analyzer( case _ => plan } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => @@ -787,7 +793,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -845,11 +851,10 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - q transformExpressionsUp { + q.transformExpressionsUp { case u @ UnresolvedAttribute(nameParts) => - // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = - withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + // Leave unchanged if resolution fails. Hopefully will be resolved next round. + val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -962,11 +967,11 @@ class Analyzer( * have no effect on the results. */ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. - case s @ Sort(orders, global, child) + case Sort(orders, global, child) if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) => val newOrders = orders map { case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) => @@ -983,17 +988,11 @@ class Analyzer( // Replace the index with the corresponding expression in aggregateExpressions. The index is // a 1-base position of aggregateExpressions, which is output columns (select expression) - case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && + case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && groups.exists(_.isInstanceOf[UnresolvedOrdinal]) => val newGroups = groups.map { - case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => - aggs(index - 1) match { - case e if ResolveAggregateFunctions.containsAggregate(e) => - ordinal.failAnalysis( - s"GROUP BY position $index is an aggregate function, and " + - "aggregate functions are not allowed in GROUP BY") - case o => o - } + case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => + aggs(index - 1) case ordinal @ UnresolvedOrdinal(index) => ordinal.failAnalysis( s"GROUP BY position $index is not in select list " + @@ -1004,6 +1003,41 @@ class Analyzer( } } + /** + * Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses. + * This rule is expected to run after [[ResolveReferences]] applied. + */ + object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] { + + // This is a strict check though, we put this to apply the rule only if the expression is not + // resolvable by child. + private def notResolvableByChild(attrName: String, child: LogicalPlan): Boolean = { + !child.output.exists(a => resolver(a.name, attrName)) + } + + private def mayResolveAttrByAggregateExprs( + exprs: Seq[Expression], aggs: Seq[NamedExpression], child: LogicalPlan): Seq[Expression] = { + exprs.map { _.transform { + case u: UnresolvedAttribute if notResolvableByChild(u.name, child) => + aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u) + }} + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case agg @ Aggregate(groups, aggs, child) + if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && + groups.exists(!_.resolved) => + agg.copy(groupingExpressions = mayResolveAttrByAggregateExprs(groups, aggs, child)) + + case gs @ GroupingSets(selectedGroups, groups, child, aggs) + if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && + groups.exists(_.isInstanceOf[UnresolvedAttribute]) => + gs.copy( + selectedGroupByExprs = selectedGroups.map(mayResolveAttrByAggregateExprs(_, aggs, child)), + groupByExprs = mayResolveAttrByAggregateExprs(groups, aggs, child)) + } + } + /** * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT * clause. This rule detects such queries and adds the required attributes to the original @@ -1013,7 +1047,7 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions case sa @ Sort(_, _, child: Aggregate) => sa @@ -1137,7 +1171,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1161,11 +1195,21 @@ class Analyzer( // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within // the context of a Window clause. They do not need to be wrapped in an // AggregateExpression. - case wf: AggregateWindowFunction => wf + case wf: AggregateWindowFunction => + if (isDistinct) { + failAnalysis(s"${wf.prettyName} does not support the modifier DISTINCT") + } else { + wf + } // We get an aggregate function, we need to wrap it in an AggregateExpression. case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct) // This function is not an aggregate function, just return the resolved one. - case other => other + case other => + if (isDistinct) { + failAnalysis(s"${other.prettyName} does not support the modifier DISTINCT") + } else { + other + } } } } @@ -1283,7 +1327,7 @@ class Analyzer( // Category 1: // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias - case _: BroadcastHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => + case _: ResolvedHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => // Category 2: // These operators can be anywhere in a correlated subquery. @@ -1449,7 +1493,7 @@ class Analyzer( /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1464,7 +1508,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1490,7 +1534,7 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case filter @ Filter(havingCondition, aggregate @ Aggregate(grouping, originalAggExprs, child)) if aggregate.resolved => @@ -1662,7 +1706,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1720,7 +1764,7 @@ class Analyzer( * that wrap the [[Generator]]. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -2037,7 +2081,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -2082,7 +2126,7 @@ class Analyzer( * and we should return null if the input is null. */ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { @@ -2147,7 +2191,7 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) @@ -2212,7 +2256,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2230,8 +2274,8 @@ class Analyzer( val result = resolved transformDown { case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved => inputData.dataType match { - case ArrayType(et, _) => - val expr = MapObjects(func, inputData, et, cls) transformUp { + case ArrayType(et, cn) => + val expr = MapObjects(func, inputData, et, cn, cls) transformUp { case UnresolvedExtractValue(child, fieldName) if child.resolved => ExtractValue(child, fieldName, resolver) } @@ -2298,7 +2342,7 @@ class Analyzer( * constructed is an inner class. */ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2332,7 +2376,7 @@ class Analyzer( "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2347,23 +2391,6 @@ class Analyzer( } } } - - /** - * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local - * time zone. - */ - object ResolveTimeZone extends Rule[LogicalPlan] { - - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { - case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => - e.withTimeZone(conf.sessionLocalTimeZone) - // Casts could be added in the subquery plan through the rule TypeCoercion while coercing - // the types between the value expression and list query expression of IN expression. - // We need to subject the subquery plan through ResolveTimeZone again to setup timezone - // information for time zone aware expressions. - case e: ListQuery => e.withNewPlan(apply(e.plan)) - } - } } /** @@ -2388,7 +2415,9 @@ object EliminateUnions extends Rule[LogicalPlan] { /** * Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level * expression in Project(project list) or Aggregate(aggregate expressions) or - * Window(window expressions). + * Window(window expressions). Notice that if an expression has other expression parameters which + * are not in its `children`, e.g. `RuntimeReplaceable`, the transformation for Aliases in this + * rule can't work for those parameters. */ object CleanupAliases extends Rule[LogicalPlan] { private def trimAliases(e: Expression): Expression = { @@ -2403,7 +2432,7 @@ object CleanupAliases extends Rule[LogicalPlan] { case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2432,6 +2461,16 @@ object CleanupAliases extends Rule[LogicalPlan] { } } +/** + * Ignore event time watermark in batch query, which is only supported in Structured Streaming. + * TODO: add this rule into analyzer rule list. + */ +object EliminateEventTimeWatermark extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case EventTimeWatermark(_, _, child) if !child.isStreaming => child + } +} + /** * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to * figure out how many windows a time column can map to, we over-estimate the number of windows and @@ -2471,7 +2510,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index da0c6b098f5ce..2e3ac3e474866 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -130,12 +130,13 @@ trait CheckAnalysis extends PredicateHelper { } case s @ ScalarSubquery(query, conditions, _) => + checkAnalysis(query) + // If no correlation, the output must be exactly one column if (conditions.isEmpty && query.output.size != 1) { failAnalysis( s"Scalar subquery must return only one column, but got ${query.output.size}") - } - else if (conditions.nonEmpty) { + } else if (conditions.nonEmpty) { def checkAggregate(agg: Aggregate): Unit = { // Make sure correlated scalar subqueries contain one row for every outer row by // enforcing that they are aggregates containing exactly one aggregate expression. @@ -179,7 +180,6 @@ trait CheckAnalysis extends PredicateHelper { case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail") } } - checkAnalysis(query) s case s: SubqueryExpression => @@ -254,6 +254,11 @@ trait CheckAnalysis extends PredicateHelper { } def checkValidGroupingExprs(expr: Expression): Unit = { + if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) { + failAnalysis( + "aggregate functions are not allowed in GROUP BY, but found " + expr.sql) + } + // Check if the data type of expr is orderable. if (!RowOrdering.isOrderable(expr.dataType)) { failAnalysis( @@ -271,8 +276,8 @@ trait CheckAnalysis extends PredicateHelper { } } - aggregateExprs.foreach(checkValidAggregateExpression) groupingExprs.foreach(checkValidGroupingExprs) + aggregateExprs.foreach(checkValidAggregateExpression) case Sort(orders, _, _) => orders.foreach { order => @@ -394,7 +399,7 @@ trait CheckAnalysis extends PredicateHelper { |in operator ${operator.simpleString} """.stripMargin) - case _: Hint => + case _: UnresolvedHint => throw new IllegalStateException( "Internal error: logical hint operator should have been removed during analysis") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 9c38dd2ee4e53..fd2ac78b25dbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -80,7 +80,7 @@ object DecimalPrecision extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // fix decimal precision for expressions - case q => q.transformExpressions( + case q => q.transformExpressionsUp( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e1d83a86f99dc..f8ed89bdc270b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.lang.reflect.Modifier + import scala.language.existentials import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} @@ -216,6 +218,7 @@ object FunctionRegistry { expression[UnaryMinus]("negative"), expression[Pi]("pi"), expression[Pmod]("pmod"), + expression[Fmod]("fmod"), expression[UnaryPositive]("positive"), expression[Pow]("pow"), expression[Pow]("power"), @@ -428,6 +431,8 @@ object FunctionRegistry { expression[StructsToJson]("to_json"), expression[JsonToStructs]("from_json"), + // cast + expression[Cast]("cast"), // Cast aliases (SPARK-16730) castAlias("boolean", BooleanType), castAlias("tinyint", ByteType), @@ -455,8 +460,17 @@ object FunctionRegistry { private def expression[T <: Expression](name: String) (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { + // For `RuntimeReplaceable`, skip the constructor with most arguments, which is the main + // constructor and contains non-parameter `child` and should not be used as function builder. + val constructors = if (classOf[RuntimeReplaceable].isAssignableFrom(tag.runtimeClass)) { + val all = tag.runtimeClass.getConstructors + val maxNumArgs = all.map(_.getParameterCount).max + all.filterNot(_.getParameterCount == maxNumArgs) + } else { + tag.runtimeClass.getConstructors + } // See if we can find a constructor that accepts Seq[Expression] - val varargCtor = Try(tag.runtimeClass.getDeclaredConstructor(classOf[Seq[_]])).toOption + val varargCtor = constructors.find(_.getParameterTypes.toSeq == Seq(classOf[Seq[_]])) val builder = (expressions: Seq[Expression]) => { if (varargCtor.isDefined) { // If there is an apply method that accepts Seq[Expression], use that one. @@ -470,11 +484,8 @@ object FunctionRegistry { } else { // Otherwise, find a constructor method that matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) - val f = Try(tag.runtimeClass.getDeclaredConstructor(params : _*)) match { - case Success(e) => - e - case Failure(e) => - throw new AnalysisException(s"Invalid number of arguments for function $name") + val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse { + throw new AnalysisException(s"Invalid number of arguments for function $name") } Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { case Success(e) => e @@ -504,7 +515,9 @@ object FunctionRegistry { } Cast(args.head, dataType) } - (name, (expressionInfo[Cast](name), builder)) + val clazz = scala.reflect.classTag[Cast].runtimeClass + val usage = "_FUNC_(expr) - Casts the value `expr` to the target data type `_FUNC_`." + (name, (new ExpressionInfo(clazz.getCanonicalName, null, name, usage, null), builder)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index c4827b81e8b63..62a3482d9fac1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin @@ -57,11 +58,11 @@ object ResolveHints { val newNode = CurrentOrigin.withOrigin(plan.origin) { plan match { case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) => - BroadcastHint(plan) + ResolvedHint(plan, HintInfo(isBroadcastable = Option(true))) case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) => - BroadcastHint(plan) + ResolvedHint(plan, HintInfo(isBroadcastable = Option(true))) - case _: BroadcastHint | _: View | _: With | _: SubqueryAlias => + case _: ResolvedHint | _: View | _: With | _: SubqueryAlias => // Don't traverse down these nodes. // For an existing broadcast hint, there is no point going down (if we do, we either // won't change the structure, or will introduce another broadcast hint that is useless. @@ -85,8 +86,19 @@ object ResolveHints { } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => - applyBroadcastHint(h.child, h.parameters.toSet) + case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => + if (h.parameters.isEmpty) { + // If there is no table alias specified, turn the entire subtree into a BroadcastHint. + ResolvedHint(h.child, HintInfo(isBroadcastable = Option(true))) + } else { + // Otherwise, find within the subtree query plans that should be broadcasted. + applyBroadcastHint(h.child, h.parameters.map { + case tableName: String => tableName + case tableId: UnresolvedAttribute => tableId.name + case unsupported => throw new AnalysisException("Broadcast hint parameter should be " + + s"an identifier or string but was $unsupported (${unsupported.getClass}") + }.toSet) + } } } @@ -96,7 +108,7 @@ object ResolveHints { */ object RemoveAllHints extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case h: Hint => h.child + case h: UnresolvedHint => h.child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index a991dd96e2828..f2df3e132629f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.analysis import scala.util.control.NonFatal import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf @@ -29,7 +28,7 @@ import org.apache.spark.sql.types.{StructField, StructType} /** * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. */ -case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] { +case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case table: UnresolvedInlineTable if table.expressionsResolved => validateInputDimension(table) @@ -99,12 +98,9 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] { val castedExpr = if (e.dataType.sameType(targetType)) { e } else { - Cast(e, targetType) + cast(e, targetType) } - castedExpr.transform { - case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => - e.withTimeZone(conf.sessionLocalTimeZone) - }.eval() + castedExpr.eval() } catch { case NonFatal(ex) => table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index 8841309939c24..de6de24350f23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules._ @@ -103,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => - builtinFunctions.get(u.functionName.toLowerCase()) match { + builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => val resolved = tvf.flatMap { case (argList, resolver) => argList.implicitCast(u.functionArgs) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index 256b18771052a..860d20f897690 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -33,7 +33,7 @@ class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] { case _ => false } - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => val newOrders = s.order.map { case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 3f76f26dbe4ec..6ab4153bac70e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -267,7 +267,7 @@ object UnsupportedOperationChecker { throwError("Limits are not supported on streaming DataFrames/Datasets") case Sort(_, _, _) if !containsCompleteData(subPlan) => - throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on" + + throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " + "aggregated DataFrame/Dataset in Complete output mode") case Sample(_, _, _, _, child) if child.isStreaming => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala new file mode 100644 index 0000000000000..a27aa845bf0ae --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ListQuery, TimeZoneAwareExpression} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DataType + +/** + * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local + * time zone. + */ +case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { + private val transformTimeZoneExprs: PartialFunction[Expression, Expression] = { + case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => + e.withTimeZone(conf.sessionLocalTimeZone) + // Casts could be added in the subquery plan through the rule TypeCoercion while coercing + // the types between the value expression and list query expression of IN expression. + // We need to subject the subquery plan through ResolveTimeZone again to setup timezone + // information for time zone aware expressions. + case e: ListQuery => e.withNewPlan(apply(e.plan)) + } + + override def apply(plan: LogicalPlan): LogicalPlan = + plan.resolveExpressions(transformTimeZoneExprs) + + def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs) +} + +/** + * Mix-in trait for constructing valid [[Cast]] expressions. + */ +trait CastSupport { + /** + * Configuration used to create a valid cast expression. + */ + def conf: SQLConf + + /** + * Create a Cast expression with the session local time zone. + */ + def cast(child: Expression, dataType: DataType): Cast = { + Cast(child, dataType, Option(conf.sessionLocalTimeZone)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index 3bd54c257d98d..ea46dd7282401 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf * This should be only done after the batch of Resolution, because the view attributes are not * completely resolved during the batch of Resolution. */ -case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] { +case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver @@ -78,7 +78,7 @@ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] { throw new AnalysisException(s"Cannot up cast ${originAttr.sql} from " + s"${originAttr.dataType.simpleString} to ${attr.simpleString} as it may truncate\n") } else { - Alias(Cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId, + Alias(cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId, qualifier = attr.qualifier, explicitMetadata = Some(attr.metadata)) } case (_, originAttr) => originAttr diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 08a01e8601897..a17ca7459fca8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -18,8 +18,10 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException} +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.HadoopFileSelector import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ListenerBus /** * Interface for the system catalog (of functions, partitions, tables, and databases). @@ -30,7 +32,8 @@ import org.apache.spark.sql.types.StructType * * Implementations should throw [[NoSuchDatabaseException]] when databases don't exist. */ -abstract class ExternalCatalog { +abstract class ExternalCatalog + extends ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { import CatalogTypes.TablePartitionSpec protected def requireDbExists(db: String): Unit = { @@ -61,9 +64,22 @@ abstract class ExternalCatalog { // Databases // -------------------------------------------------------------------------- - def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit + final def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { + val db = dbDefinition.name + postToAll(CreateDatabasePreEvent(db)) + doCreateDatabase(dbDefinition, ignoreIfExists) + postToAll(CreateDatabaseEvent(db)) + } + + protected def doCreateDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit + + final def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { + postToAll(DropDatabasePreEvent(db)) + doDropDatabase(db, ignoreIfNotExists, cascade) + postToAll(DropDatabaseEvent(db)) + } - def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + protected def doDropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit /** * Alter a database whose name matches the one specified in `dbDefinition`, @@ -88,11 +104,39 @@ abstract class ExternalCatalog { // Tables // -------------------------------------------------------------------------- - def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit + final def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + postToAll(CreateTablePreEvent(db, name)) + doCreateTable(tableDefinition, ignoreIfExists) + postToAll(CreateTableEvent(db, name)) + } - def dropTable(db: String, table: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit + protected def doCreateTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - def renameTable(db: String, oldName: String, newName: String): Unit + final def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit = { + postToAll(DropTablePreEvent(db, table)) + doDropTable(db, table, ignoreIfNotExists, purge) + postToAll(DropTableEvent(db, table)) + } + + protected def doDropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit + + final def renameTable(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameTablePreEvent(db, oldName, newName)) + doRenameTable(db, oldName, newName) + postToAll(RenameTableEvent(db, oldName, newName)) + } + + protected def doRenameTable(db: String, oldName: String, newName: String): Unit /** * Alter a table whose database and name match the ones specified in `tableDefinition`, assuming @@ -119,8 +163,6 @@ abstract class ExternalCatalog { def getTable(db: String, table: String): CatalogTable - def getTableOption(db: String, table: String): Option[CatalogTable] - def tableExists(db: String, table: String): Boolean def listTables(db: String): Seq[String] @@ -269,11 +311,30 @@ abstract class ExternalCatalog { // Functions // -------------------------------------------------------------------------- - def createFunction(db: String, funcDefinition: CatalogFunction): Unit + final def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(CreateFunctionPreEvent(db, name)) + doCreateFunction(db, funcDefinition) + postToAll(CreateFunctionEvent(db, name)) + } + + protected def doCreateFunction(db: String, funcDefinition: CatalogFunction): Unit + + final def dropFunction(db: String, funcName: String): Unit = { + postToAll(DropFunctionPreEvent(db, funcName)) + doDropFunction(db, funcName) + postToAll(DropFunctionEvent(db, funcName)) + } + + protected def doDropFunction(db: String, funcName: String): Unit - def dropFunction(db: String, funcName: String): Unit + final def renameFunction(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameFunctionPreEvent(db, oldName, newName)) + doRenameFunction(db, oldName, newName) + postToAll(RenameFunctionEvent(db, oldName, newName)) + } - def renameFunction(db: String, oldName: String, newName: String): Unit + protected def doRenameFunction(db: String, oldName: String, newName: String): Unit def getFunction(db: String, funcName: String): CatalogFunction @@ -281,4 +342,19 @@ abstract class ExternalCatalog { def listFunctions(db: String, pattern: String): Seq[String] + def setTableNamePreprocessor(newTableNamePreprocessor: (String) => String): Unit + + def getTableNamePreprocessor: (String) => String + + def setHadoopFileSelector(hadoopFileSelector: HadoopFileSelector): Unit + + def unsetHadoopFileSelector(): Unit + + def findHadoopFileSelector: Option[HadoopFileSelector] + + override protected def doPostEvent( + listener: ExternalCatalogEventListener, + event: ExternalCatalogEvent): Unit = { + listener.onEvent(event) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index 3ca9e6a8da5b5..57ed95250f402 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI import java.util.Locale -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.util.Shell import org.apache.spark.sql.AnalysisException @@ -155,10 +155,37 @@ object ExternalCatalogUtils { }) inputPartitions.filter { p => - boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId)) + boundPredicate.eval(p.toRow(partitionSchema, defaultTimeZoneId)) } } } + + /** + * Returns true if `spec1` is a partial partition spec w.r.t. `spec2`, e.g. PARTITION (a=1) is a + * partial partition spec w.r.t. PARTITION (a=1,b=2). + */ + def isPartialPartitionSpec( + spec1: TablePartitionSpec, + spec2: TablePartitionSpec): Boolean = { + spec1.forall { + case (partitionColumn, value) => spec2(partitionColumn) == value + } + } + + abstract class HadoopFileSelector { + /** + * Select files constituting a table from the given base path according to the client's custom + * algorithm. This is only applied to non-partitioned tables. + * @param tableName table name to select files for. This is the exact table name specified + * in the query, not a "preprocessed" file name returned by the user-defined + * function registered via [[ExternalCatalog.setTableNamePreprocessor]]. + * @param fs the filesystem containing the table + * @param basePath base path of the table in the filesystem + * @return a set of files, or [[None]] if the custom file selection algorithm does not apply + * to this table. + */ + def selectFiles(tableName: String, fs: FileSystem, basePath: Path): Option[Seq[Path]] + } } object CatalogUtils { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 9ca1c71d1dcb1..fbcb8724f4740 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -98,7 +98,7 @@ class InMemoryCatalog( // Databases // -------------------------------------------------------------------------- - override def createDatabase( + override protected def doCreateDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { @@ -119,7 +119,7 @@ class InMemoryCatalog( } } - override def dropDatabase( + override protected def doDropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = synchronized { @@ -180,7 +180,7 @@ class InMemoryCatalog( // Tables // -------------------------------------------------------------------------- - override def createTable( + override protected def doCreateTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) @@ -221,7 +221,7 @@ class InMemoryCatalog( } } - override def dropTable( + override protected def doDropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -264,7 +264,10 @@ class InMemoryCatalog( } } - override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized { + override protected def doRenameTable( + db: String, + oldName: String, + newName: String): Unit = synchronized { requireTableExists(db, oldName) requireTableNotExists(db, newName) val oldDesc = catalog(db).tables(oldName) @@ -312,10 +315,6 @@ class InMemoryCatalog( catalog(db).tables(table).table } - override def getTableOption(db: String, table: String): Option[CatalogTable] = synchronized { - if (!tableExists(db, table)) None else Option(catalog(db).tables(table).table) - } - override def tableExists(db: String, table: String): Boolean = synchronized { requireDbExists(db) catalog(db).tables.contains(table) @@ -539,18 +538,6 @@ class InMemoryCatalog( } } - /** - * Returns true if `spec1` is a partial partition spec w.r.t. `spec2`, e.g. PARTITION (a=1) is a - * partial partition spec w.r.t. PARTITION (a=1,b=2). - */ - private def isPartialPartitionSpec( - spec1: TablePartitionSpec, - spec2: TablePartitionSpec): Boolean = { - spec1.forall { - case (partitionColumn, value) => spec2(partitionColumn) == value - } - } - override def listPartitionsByFilter( db: String, table: String, @@ -565,18 +552,21 @@ class InMemoryCatalog( // Functions // -------------------------------------------------------------------------- - override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { + override protected def doCreateFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionNotExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override def dropFunction(db: String, funcName: String): Unit = synchronized { + override protected def doDropFunction(db: String, funcName: String): Unit = synchronized { requireFunctionExists(db, funcName) catalog(db).functions.remove(funcName) } - override def renameFunction(db: String, oldName: String, newName: String): Unit = synchronized { + override protected def doRenameFunction( + db: String, + oldName: String, + newName: String): Unit = synchronized { requireFunctionExists(db, oldName) requireFunctionNotExists(db, newName) val newFunc = getFunction(db, oldName).copy(identifier = FunctionIdentifier(newName, Some(db))) @@ -599,4 +589,13 @@ class InMemoryCatalog( StringUtils.filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) } + override def setTableNamePreprocessor(newTableNamePreprocessor: (String) => String): Unit = {} + + def getTableNamePreprocessor: (String) => String = identity + + override def setHadoopFileSelector(hadoopFileSelector: HadoopFileSelector): Unit = {} + + override def unsetHadoopFileSelector(): Unit = {} + + override def findHadoopFileSelector: Option[HadoopFileSelector] = None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 3fbf83f3a38a2..f1e650762ef17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI import java.util.Locale +import java.util.concurrent.Callable import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -73,7 +74,7 @@ class SessionCatalog( functionRegistry, conf, new Configuration(), - CatalystSqlParser, + new CatalystSqlParser(conf), DummyFunctionResourceLoader) } @@ -115,24 +116,46 @@ class SessionCatalog( * Format table name, taking into account case sensitivity. */ protected[this] def formatTableName(name: String): String = { - if (conf.caseSensitiveAnalysis) name else name.toLowerCase + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } /** * Format database name, taking into account case sensitivity. */ protected[this] def formatDatabaseName(name: String): String = { - if (conf.caseSensitiveAnalysis) name else name.toLowerCase + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } - /** - * A cache of qualified table names to table relation plans. - */ - val tableRelationCache: Cache[QualifiedTableName, LogicalPlan] = { + private val tableRelationCache: Cache[QualifiedTableName, LogicalPlan] = { val cacheSize = conf.tableRelationCacheSize CacheBuilder.newBuilder().maximumSize(cacheSize).build[QualifiedTableName, LogicalPlan]() } + /** This method provides a way to get a cached plan. */ + def getCachedPlan(t: QualifiedTableName, c: Callable[LogicalPlan]): LogicalPlan = { + tableRelationCache.get(t, c) + } + + /** This method provides a way to get a cached plan if the key exists. */ + def getCachedTable(key: QualifiedTableName): LogicalPlan = { + tableRelationCache.getIfPresent(key) + } + + /** This method provides a way to cache a plan. */ + def cacheTable(t: QualifiedTableName, l: LogicalPlan): Unit = { + tableRelationCache.put(t, l) + } + + /** This method provides a way to invalidate a cached plan. */ + def invalidateCachedTable(key: QualifiedTableName): Unit = { + tableRelationCache.invalidate(key) + } + + /** This method provides a way to invalidate all the cached plans. */ + def invalidateAllCachedTables(): Unit = { + tableRelationCache.invalidateAll() + } + /** * This method is used to make the given path qualified before we * store this path in the underlying external catalog. So, when a path @@ -365,9 +388,10 @@ class SessionCatalog( /** * Retrieve the metadata of an existing permanent table/view. If no database is specified, - * assume the table/view is in the current database. If the specified table/view is not found - * in the database then a [[NoSuchTableException]] is thrown. + * assume the table/view is in the current database. */ + @throws[NoSuchDatabaseException] + @throws[NoSuchTableException] def getTableMetadata(name: TableIdentifier): CatalogTable = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) val table = formatTableName(name.table) @@ -376,18 +400,6 @@ class SessionCatalog( externalCatalog.getTable(db, table) } - /** - * Retrieve the metadata of an existing metastore table. - * If no database is specified, assume the table is in the current database. - * If the specified table is not found in the database then return None if it doesn't exist. - */ - def getTableMetadataOption(name: TableIdentifier): Option[CatalogTable] = { - val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) - val table = formatTableName(name.table) - requireDbExists(db) - externalCatalog.getTableOption(db, table) - } - /** * Load files stored in given path into an existing metastore table. * If no database is specified, assume the table is in the current database. @@ -655,7 +667,9 @@ class SessionCatalog( SubqueryAlias(table, viewDef) }.getOrElse(throw new NoSuchTableException(db, table)) } else if (name.database.isDefined || !tempTables.contains(table)) { - val metadata = externalCatalog.getTable(db, table) + val tableNamePreprocessor = externalCatalog.getTableNamePreprocessor + val tableNameInMetastore = tableNamePreprocessor(table) + val metadata = externalCatalog.getTable(db, tableNameInMetastore).withTableName(table) if (metadata.tableType == CatalogTableType.VIEW) { val viewText = metadata.viewText.getOrElse(sys.error("Invalid view without text.")) // The relation is a view, so we wrap the relation by: @@ -667,12 +681,7 @@ class SessionCatalog( child = parser.parsePlan(viewText)) SubqueryAlias(table, child) } else { - val tableRelation = CatalogRelation( - metadata, - // we assume all the columns are nullable. - metadata.dataSchema.asNullable.toAttributes, - metadata.partitionSchema.asNullable.toAttributes) - SubqueryAlias(table, tableRelation) + SubqueryAlias(table, UnresolvedCatalogRelation(metadata)) } } else { SubqueryAlias(table, tempTables(table)) @@ -1105,8 +1114,9 @@ class SessionCatalog( !hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT)) } - protected def failFunctionLookup(name: String): Nothing = { - throw new NoSuchFunctionException(db = currentDb, func = name) + protected def failFunctionLookup(name: FunctionIdentifier): Nothing = { + throw new NoSuchFunctionException( + db = name.database.getOrElse(getCurrentDatabase), func = name.funcName) } /** @@ -1128,7 +1138,7 @@ class SessionCatalog( qualifiedName.database.orNull, qualifiedName.identifier) } else { - failFunctionLookup(name.funcName) + failFunctionLookup(name) } } } @@ -1158,8 +1168,8 @@ class SessionCatalog( } // If the name itself is not qualified, add the current database to it. - val database = name.database.orElse(Some(currentDb)).map(formatDatabaseName) - val qualifiedName = name.copy(database = database) + val database = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + val qualifiedName = name.copy(database = Some(database)) if (functionRegistry.functionExists(qualifiedName.unquotedString)) { // This function has been already loaded into the function registry. @@ -1172,10 +1182,10 @@ class SessionCatalog( // in the metastore). We need to first put the function in the FunctionRegistry. // TODO: why not just check whether the function exists first? val catalogFunction = try { - externalCatalog.getFunction(currentDb, name.funcName) + externalCatalog.getFunction(database, name.funcName) } catch { - case e: AnalysisException => failFunctionLookup(name.funcName) - case e: NoSuchPermanentFunctionException => failFunctionLookup(name.funcName) + case _: AnalysisException => failFunctionLookup(name) + case _: NoSuchPermanentFunctionException => failFunctionLookup(name) } loadFunctionResources(catalogFunction.resources) // Please note that qualifiedName is provided by the user. However, @@ -1251,9 +1261,10 @@ class SessionCatalog( dropTempFunction(func.funcName, ignoreIfNotExists = false) } } - tempTables.clear() + clearTempTables() globalTempViewManager.clear() functionRegistry.clear() + tableRelationCache.invalidateAll() // restore built-in functions FunctionRegistry.builtin.listFunction().foreach { f => val expressionInfo = FunctionRegistry.builtin.lookupFunction(f) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala new file mode 100644 index 0000000000000..459973a13bb10 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.scheduler.SparkListenerEvent + +/** + * Event emitted by the external catalog when it is modified. Events are either fired before or + * after the modification (the event should document this). + */ +trait ExternalCatalogEvent extends SparkListenerEvent + +/** + * Listener interface for external catalog modification events. + */ +trait ExternalCatalogEventListener { + def onEvent(event: ExternalCatalogEvent): Unit +} + +/** + * Event fired when a database is create or dropped. + */ +trait DatabaseEvent extends ExternalCatalogEvent { + /** + * Database of the object that was touched. + */ + val database: String +} + +/** + * Event fired before a database is created. + */ +case class CreateDatabasePreEvent(database: String) extends DatabaseEvent + +/** + * Event fired after a database has been created. + */ +case class CreateDatabaseEvent(database: String) extends DatabaseEvent + +/** + * Event fired before a database is dropped. + */ +case class DropDatabasePreEvent(database: String) extends DatabaseEvent + +/** + * Event fired after a database has been dropped. + */ +case class DropDatabaseEvent(database: String) extends DatabaseEvent + +/** + * Event fired when a table is created, dropped or renamed. + */ +trait TableEvent extends DatabaseEvent { + /** + * Name of the table that was touched. + */ + val name: String +} + +/** + * Event fired before a table is created. + */ +case class CreateTablePreEvent(database: String, name: String) extends TableEvent + +/** + * Event fired after a table has been created. + */ +case class CreateTableEvent(database: String, name: String) extends TableEvent + +/** + * Event fired before a table is dropped. + */ +case class DropTablePreEvent(database: String, name: String) extends TableEvent + +/** + * Event fired after a table has been dropped. + */ +case class DropTableEvent(database: String, name: String) extends TableEvent + +/** + * Event fired before a table is renamed. + */ +case class RenameTablePreEvent( + database: String, + name: String, + newName: String) + extends TableEvent + +/** + * Event fired after a table has been renamed. + */ +case class RenameTableEvent( + database: String, + name: String, + newName: String) + extends TableEvent + +/** + * Event fired when a function is created, dropped or renamed. + */ +trait FunctionEvent extends DatabaseEvent { + /** + * Name of the function that was touched. + */ + val name: String +} + +/** + * Event fired before a function is created. + */ +case class CreateFunctionPreEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired after a function has been created. + */ +case class CreateFunctionEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired before a function is dropped. + */ +case class DropFunctionPreEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired after a function has been dropped. + */ +case class DropFunctionEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired before a function is renamed. + */ +case class RenameFunctionPreEvent( + database: String, + name: String, + newName: String) + extends FunctionEvent + +/** + * Event fired after a function has been renamed. + */ +case class RenameFunctionEvent( + database: String, + name: String, + newName: String) + extends FunctionEvent diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index cc0cbba275b81..506826b755f09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -75,7 +75,7 @@ case class CatalogStorageFormat( CatalogUtils.maskCredentials(properties) match { case props if props.isEmpty => // No-op case props => - map.put("Properties", props.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]")) + map.put("Storage Properties", props.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]")) } map } @@ -289,7 +289,8 @@ case class CatalogTable( copy(storage = CatalogStorageFormat( locationUri, inputFormat, outputFormat, serde, compressed, properties)) } - + def withTableName(newName: String): CatalogTable = + copy(identifier = identifier.copy(table = newName)) def toLinkedHashMap: mutable.LinkedHashMap[String, String] = { val map = new mutable.LinkedHashMap[String, String]() @@ -313,7 +314,7 @@ case class CatalogTable( } } - if (properties.nonEmpty) map.put("Properties", tableProperties) + if (properties.nonEmpty) map.put("Table Properties", tableProperties) stats.foreach(s => map.put("Statistics", s.simpleString)) map ++= storage.toLinkedHashMap if (tracksPartitionsInCatalog) map.put("Partition Provider", "Catalog") @@ -397,11 +398,22 @@ object CatalogTypes { type TablePartitionSpec = Map[String, String] } +/** + * A placeholder for a table relation, which will be replaced by concrete relation like + * `LogicalRelation` or `HiveTableRelation`, during analysis. + */ +case class UnresolvedCatalogRelation(tableMeta: CatalogTable) extends LeafNode { + assert(tableMeta.identifier.database.isDefined) + override lazy val resolved: Boolean = false + override def output: Seq[Attribute] = Nil +} /** - * A [[LogicalPlan]] that represents a table. + * A `LogicalPlan` that represents a hive table. + * + * TODO: remove this after we completely make hive as a data source. */ -case class CatalogRelation( +case class HiveTableRelation( tableMeta: CatalogTable, dataCols: Seq[AttributeReference], partitionCols: Seq[AttributeReference]) extends LeafNode with MultiInstanceRelation { @@ -415,7 +427,7 @@ case class CatalogRelation( def isPartitioned: Boolean = partitionCols.nonEmpty override def equals(relation: Any): Boolean = relation match { - case other: CatalogRelation => tableMeta == other.tableMeta && output == other.output + case other: HiveTableRelation => tableMeta == other.tableMeta && output == other.output case _ => false } @@ -434,15 +446,12 @@ case class CatalogRelation( )) override def computeStats(conf: SQLConf): Statistics = { - // For data source tables, we will create a `LogicalRelation` and won't call this method, for - // hive serde tables, we will always generate a statistics. - // TODO: unify the table stats generation. tableMeta.stats.map(_.toPlanStats(output)).getOrElse { throw new IllegalStateException("table stats must be specified.") } } - override def newInstance(): LogicalPlan = copy( + override def newInstance(): HiveTableRelation = copy( dataCols = dataCols.map(_.newInstance()), partitionCols = partitionCols.map(_.newInstance())) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 75bf780d41424..85d17afe20230 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -168,6 +168,7 @@ package object dsl { case Seq() => UnresolvedStar(None) case target => UnresolvedStar(Option(target)) } + def namedStruct(e: Expression*): Expression = CreateNamedStruct(e) def callFunction[T, U]( func: T => U, @@ -366,7 +367,7 @@ package object dsl { def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( analysis.UnresolvedRelation(TableIdentifier(tableName)), - Map.empty, logicalPlan, overwrite, false) + Map.empty, logicalPlan, overwrite, ifPartitionNotExists = false) def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan) @@ -381,6 +382,9 @@ package object dsl { def analyze: LogicalPlan = EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan)) + + def hint(name: String, parameters: Any*): LogicalPlan = + UnresolvedHint(name, parameters, logicalPlan) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index bb1273f5c3d84..43df19ba009a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -89,6 +89,31 @@ object Cast { case _ => false } + /** + * Return true if we need to use the `timeZone` information casting `from` type to `to` type. + * The patterns matched reflect the current implementation in the Cast node. + * c.f. usage of `timeZone` in: + * * Cast.castToString + * * Cast.castToDate + * * Cast.castToTimestamp + */ + def needsTimeZone(from: DataType, to: DataType): Boolean = (from, to) match { + case (StringType, TimestampType) => true + case (DateType, TimestampType) => true + case (TimestampType, StringType) => true + case (TimestampType, DateType) => true + case (ArrayType(fromType, _), ArrayType(toType, _)) => needsTimeZone(fromType, toType) + case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) => + needsTimeZone(fromKey, toKey) || needsTimeZone(fromValue, toValue) + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).exists { + case (fromField, toField) => + needsTimeZone(fromField.dataType, toField.dataType) + } + case _ => false + } + /** * Return true iff we may truncate during casting `from` type to `to` type. e.g. long -> int, * timestamp -> date. @@ -165,6 +190,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) + // When this cast involves TimeZone, it's only resolved if the timeZoneId is set; + // Otherwise behave like Expression.resolved. + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined) + + private[this] def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType) + // [[func]] assumes the input is no longer null because eval already does the null check. @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) @@ -450,15 +482,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case (fromField, toField) => cast(fromField.dataType, toField.dataType) } // TODO: Could be faster? - val newRow = new GenericInternalRow(from.fields.length) buildCast[InternalRow](_, row => { + val newRow = new GenericInternalRow(from.fields.length) var i = 0 while (i < row.numFields) { newRow.update(i, if (row.isNullAt(i)) null else castFuncs(i)(row.get(i, from.apply(i).dataType))) i += 1 } - newRow.copy() + newRow }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b847ef7bfaa97..74c4cddf2b47e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -241,6 +241,10 @@ trait RuntimeReplaceable extends UnaryExpression with Unevaluable { override def nullable: Boolean = child.nullable override def foldable: Boolean = child.foldable override def dataType: DataType = child.dataType + // As this expression gets replaced at optimization with its `child" expression, + // two `RuntimeReplaceable` are considered to be semantically equal if their "child" expressions + // are semantically equal. + override lazy val canonicalized: Expression = child.canonicalized } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 80c25d0b0fb7a..fffcc7c9ef53a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -105,12 +105,22 @@ case class AggregateExpression( } // We compute the same thing regardless of our final result. - override lazy val canonicalized: Expression = + override lazy val canonicalized: Expression = { + val normalizedAggFunc = mode match { + // For PartialMerge or Final mode, the input to the `aggregateFunction` is aggregate buffers, + // and the actual children of `aggregateFunction` is not used, here we normalize the expr id. + case PartialMerge | Final => aggregateFunction.transform { + case a: AttributeReference => a.withExprId(ExprId(0)) + } + case Partial | Complete => aggregateFunction + } + AggregateExpression( - aggregateFunction.canonicalized.asInstanceOf[AggregateFunction], + normalizedAggFunc.canonicalized.asInstanceOf[AggregateFunction], mode, isDistinct, ExprId(0)) + } override def children: Seq[Expression] = aggregateFunction :: Nil override def dataType: DataType = aggregateFunction.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index f2b252259b89d..ce9158e42ea30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -505,6 +505,59 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } +@ExpressionDescription( + usage = + "_FUNC_(expr1, expr2) - Returns the float `expr1` mod `expr2` having the same sign as `expr1`.", + extended = """ + Examples: + > SELECT _FUNC_(10.5, -3.0); + 1.5 + > SELECT _FUNC_(-10.5, -3.0); + -1.5 + """) +case class Fmod(left: Expression, right: Expression) + extends BinaryExpression with Serializable with ImplicitCastInputTypes { + + override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + + override def toString: String = s"fmod($left, $right)" + + override def dataType: DataType = DoubleType + + override def eval(input: InternalRow): Any = { + val input2 = right.eval(input) + if (input2 == null || input2 == 0) { + null + } else { + val input1 = left.eval(input) + if (input1 == null) { + null + } else { + input1.asInstanceOf[Double] % input2.asInstanceOf[Double] + } + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval1 = left.genCode(ctx) + val eval2 = right.genCode(ctx) + ev.copy(code = s""" + ${eval2.code} + boolean ${ev.isNull} = ${eval2.isNull} || ${eval2.value} == 0; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval1.code} + if (!${eval1.isNull}) { + ${ev.value} = ${eval1.value} % ${eval2.value}; + } else { + ${ev.isNull} = true; + } + } + """) + } + +} + /** * A function that returns the least value of all parameters, skipping null values. * It takes at least 2 parameters, and returns null iff all parameters are null. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 760ead42c762c..f8da78b5f5e3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -27,7 +27,10 @@ import scala.language.existentials import scala.util.control.NonFatal import com.google.common.cache.{CacheBuilder, CacheLoader} -import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleCompiler} +import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException} +import org.apache.commons.lang3.exception.ExceptionUtils +import org.codehaus.commons.compiler.CompileException +import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, JaninoRuntimeException, SimpleCompiler} import org.codehaus.janino.util.ClassFile import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException} @@ -899,8 +902,14 @@ object CodeGenerator extends Logging { /** * Compile the Java source code into a Java class, using Janino. */ - def compile(code: CodeAndComment): GeneratedClass = { + def compile(code: CodeAndComment): GeneratedClass = try { cache.get(code) + } catch { + // Cache.get() may wrap the original exception. See the following URL + // http://google.github.io/guava/releases/14.0/api/docs/com/google/common/cache/ + // Cache.html#get(K,%20java.util.concurrent.Callable) + case e @ (_: UncheckedExecutionException | _: ExecutionError) => + throw e.getCause } /** @@ -951,10 +960,14 @@ object CodeGenerator extends Logging { evaluator.cook("generated.java", code.body) recordCompilationStats(evaluator) } catch { - case e: Exception => + case e: JaninoRuntimeException => val msg = s"failed to compile: $e\n$formatted" logError(msg, e) - throw new Exception(msg, e) + throw new JaninoRuntimeException(msg, e) + case e: CompileException => + val msg = s"failed to compile: $e\n$formatted" + logError(msg, e) + throw new CompileException(msg, e.getLocation) } evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7e4c9089a2cb9..b358102d914bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -50,10 +50,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro fieldTypes: Seq[DataType], bufferHolder: String): String = { val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - val fieldName = ctx.freshName("fieldName") - val code = s"final ${ctx.javaType(dt)} $fieldName = ${ctx.getValue(input, dt, i.toString)};" - val isNull = s"$input.isNullAt($i)" - ExprCode(code, isNull, fieldName) + val javaType = ctx.javaType(dt) + val isNullVar = ctx.freshName("isNull") + val valueVar = ctx.freshName("value") + val defaultValue = ctx.defaultValue(dt) + val readValue = ctx.getValue(input, dt, i.toString) + val code = + s""" + boolean $isNullVar = $input.isNullAt($i); + $javaType $valueVar = $isNullVar ? $defaultValue : $readValue; + """ + ExprCode(code, isNullVar, valueVar) } s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index f8fe774823e5b..0ab72074b480c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -24,7 +24,6 @@ import java.util.{Calendar, TimeZone} import scala.util.control.NonFatal import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -34,6 +33,9 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} * Common base class for time zone aware expressions. */ trait TimeZoneAwareExpression extends Expression { + /** The expression is only resolved when the time zone has been set. */ + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && timeZoneId.isDefined /** the timezone ID to be used to evaluate value. */ def timeZoneId: Option[String] @@ -41,7 +43,7 @@ trait TimeZoneAwareExpression extends Expression { /** Returns a copy of this expression with the specified timeZoneId. */ def withTimeZone(timeZoneId: String): TimeZoneAwareExpression - @transient lazy val timeZone: TimeZone = TimeZone.getTimeZone(timeZoneId.get) + @transient lazy val timeZone: TimeZone = DateTimeUtils.getTimeZone(timeZoneId.get) } /** @@ -400,13 +402,15 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa } } +// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(date) - Returns the week of the year of the given date.", + usage = "_FUNC_(date) - Returns the week of the year of the given date. A week is considered to start on a Monday and week 1 is the first week with >3 days.", extended = """ Examples: > SELECT _FUNC_('2008-02-20'); 8 """) +// scalastyle:on line.size.limit case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -414,7 +418,7 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa override def dataType: DataType = IntegerType @transient private lazy val c = { - val c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + val c = Calendar.getInstance(DateTimeUtils.getTimeZone("UTC")) c.setFirstDayOfWeek(Calendar.MONDAY) c.setMinimalDaysInFirstWeek(4) c @@ -429,9 +433,10 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val c = ctx.freshName("cal") + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") ctx.addMutableState(cal, c, s""" - $c = $cal.getInstance(java.util.TimeZone.getTimeZone("UTC")); + $c = $cal.getInstance($dtu.getTimeZone("UTC")); $c.setFirstDayOfWeek($cal.MONDAY); $c.setMinimalDaysInFirstWeek(4); """) @@ -952,8 +957,9 @@ case class FromUTCTimestamp(left: Expression, right: Expression) val tzTerm = ctx.freshName("tz") val utcTerm = ctx.freshName("utc") val tzClass = classOf[TimeZone].getName - ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $tzClass.getTimeZone("$tz");""") - ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $tzClass.getTimeZone("UTC");""") + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""") + ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) ev.copy(code = s""" |${eval.code} @@ -1123,8 +1129,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression) val tzTerm = ctx.freshName("tz") val utcTerm = ctx.freshName("utc") val tzClass = classOf[TimeZone].getName - ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $tzClass.getTimeZone("$tz");""") - ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $tzClass.getTimeZone("UTC");""") + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $dtu.getTimeZone("$tz");""") + ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) ev.copy(code = s""" |${eval.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index df4d406b84d60..6b90354367f40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{ByteArrayOutputStream, CharArrayWriter, StringWriter} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, CharArrayWriter, InputStreamReader, StringWriter} import scala.util.parsing.combinator.RegexParsers @@ -149,7 +149,9 @@ case class GetJsonObject(json: Expression, path: Expression) if (parsed.isDefined) { try { - Utils.tryWithResource(jsonFactory.createParser(jsonStr.getBytes)) { parser => + /* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson + detect character encoding which could fail for some malformed strings */ + Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, jsonStr)) { parser => val output = new ByteArrayOutputStream() val matched = Utils.tryWithResource( jsonFactory.createGenerator(output, JsonEncoding.UTF8)) { generator => @@ -393,8 +395,10 @@ case class JsonTuple(children: Seq[Expression]) } try { - Utils.tryWithResource(jsonFactory.createParser(json.getBytes)) { - parser => parseRow(parser, input) + /* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson + detect character encoding which could fail for some malformed strings */ + Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser => + parseRow(parser, input) } } catch { case _: JsonProcessingException => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index c4d47ab2084fd..1c61428c57f18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -232,18 +232,20 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" } override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(DoubleType, DecimalType)) + Seq(TypeCollection(DoubleType, DecimalType, LongType)) protected override def nullSafeEval(input: Any): Any = child.dataType match { + case LongType => input.asInstanceOf[Long] case DoubleType => f(input.asInstanceOf[Double]).toLong - case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil + case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].ceil } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") - case DecimalType.Fixed(precision, scale) => + case DecimalType.Fixed(_, _) => defineCodeGen(ctx, ev, c => s"$c.ceil()") + case LongType => defineCodeGen(ctx, ev, c => s"$c") case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") } } @@ -281,7 +283,7 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH" > SELECT _FUNC_('100', 2, 10); 4 > SELECT _FUNC_(-10, 16, -10); - 16 + -16 """) case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -347,18 +349,20 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO } override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(DoubleType, DecimalType)) + Seq(TypeCollection(DoubleType, DecimalType, LongType)) protected override def nullSafeEval(input: Any): Any = child.dataType match { + case LongType => input.asInstanceOf[Long] case DoubleType => f(input.asInstanceOf[Double]).toLong - case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor + case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].floor } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") - case DecimalType.Fixed(precision, scale) => + case DecimalType.Fixed(_, _) => defineCodeGen(ctx, ev, c => s"$c.floor()") + case LongType => defineCodeGen(ctx, ev, c => s"$c") case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") } } @@ -966,7 +970,7 @@ case class Logarithm(left: Expression, right: Expression) * * @param child expr to be round, all [[NumericType]] is allowed as Input * @param scale new scale to be round to, this should be a constant int at runtime - * @param mode rounding mode (e.g. HALF_UP, HALF_UP) + * @param mode rounding mode (e.g. HALF_UP, HALF_EVEN) * @param modeStr rounding mode string name (e.g. "ROUND_HALF_UP", "ROUND_HALF_EVEN") */ abstract class RoundBase(child: Expression, scale: Expression, @@ -1023,10 +1027,10 @@ abstract class RoundBase(child: Expression, scale: Expression, // not overriding since _scale is a constant int at runtime def nullSafeEval(input1: Any): Any = { - child.dataType match { - case _: DecimalType => + dataType match { + case DecimalType.Fixed(_, s) => val decimal = input1.asInstanceOf[Decimal] - decimal.toPrecision(decimal.precision, _scale, mode).orNull + decimal.toPrecision(decimal.precision, s, mode).orNull case ByteType => BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte case ShortType => @@ -1055,10 +1059,10 @@ abstract class RoundBase(child: Expression, scale: Expression, override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val ce = child.genCode(ctx) - val evaluationCode = child.dataType match { - case _: DecimalType => + val evaluationCode = dataType match { + case DecimalType.Fixed(_, s) => s""" - if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale}, + if (${ce.value}.changePrecision(${ce.value}.precision(), ${s}, java.math.BigDecimal.${modeStr})) { ${ev.value} = ${ce.value}; } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f446c3e4a75f6..43cef6ca952c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -451,6 +451,8 @@ object MapObjects { * @param function The function applied on the collection elements. * @param inputData An expression that when evaluated returns a collection object. * @param elementType The data type of elements in the collection. + * @param elementNullable When false, indicating elements in the collection are always + * non-null value. * @param customCollectionCls Class of the resulting collection (returning ObjectType) * or None (returning ArrayType) */ @@ -458,11 +460,12 @@ object MapObjects { function: Expression => Expression, inputData: Expression, elementType: DataType, + elementNullable: Boolean = true, customCollectionCls: Option[Class[_]] = None): MapObjects = { val id = curId.getAndIncrement() val loopValue = s"MapObjects_loopValue$id" val loopIsNull = s"MapObjects_loopIsNull$id" - val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) + val loopVar = LambdaVariable(loopValue, loopIsNull, elementType, elementNullable) MapObjects( loopValue, loopIsNull, elementType, function(loopVar), inputData, customCollectionCls) } @@ -656,18 +659,21 @@ object ExternalMapToCatalyst { inputMap: Expression, keyType: DataType, keyConverter: Expression => Expression, + keyNullable: Boolean, valueType: DataType, valueConverter: Expression => Expression, valueNullable: Boolean): ExternalMapToCatalyst = { val id = curId.getAndIncrement() val keyName = "ExternalMapToCatalyst_key" + id + val keyIsNull = "ExternalMapToCatalyst_key_isNull" + id val valueName = "ExternalMapToCatalyst_value" + id val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id ExternalMapToCatalyst( keyName, + keyIsNull, keyType, - keyConverter(LambdaVariable(keyName, "false", keyType, false)), + keyConverter(LambdaVariable(keyName, keyIsNull, keyType, keyNullable)), valueName, valueIsNull, valueType, @@ -683,6 +689,8 @@ object ExternalMapToCatalyst { * * @param key the name of the map key variable that used when iterate the map, and used as input for * the `keyConverter` + * @param keyIsNull the nullability of the map key variable that used when iterate the map, and + * used as input for the `keyConverter` * @param keyType the data type of the map key variable that used when iterate the map, and used as * input for the `keyConverter` * @param keyConverter A function that take the `key` as input, and converts it to catalyst format. @@ -698,6 +706,7 @@ object ExternalMapToCatalyst { */ case class ExternalMapToCatalyst private( key: String, + keyIsNull: String, keyType: DataType, keyConverter: Expression, value: String, @@ -726,6 +735,13 @@ case class ExternalMapToCatalyst private( val entry = ctx.freshName("entry") val entries = ctx.freshName("entries") + val keyElementJavaType = ctx.javaType(keyType) + val valueElementJavaType = ctx.javaType(valueType) + ctx.addMutableState("boolean", keyIsNull, "") + ctx.addMutableState(keyElementJavaType, key, "") + ctx.addMutableState("boolean", valueIsNull, "") + ctx.addMutableState(valueElementJavaType, value, "") + val (defineEntries, defineKeyValue) = child.dataType match { case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) => val javaIteratorCls = classOf[java.util.Iterator[_]].getName @@ -737,8 +753,8 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next(); - ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry.getKey(); - ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry.getValue(); + $key = (${ctx.boxedType(keyType)}) $entry.getKey(); + $value = (${ctx.boxedType(valueType)}) $entry.getValue(); """ defineEntries -> defineKeyValue @@ -752,17 +768,23 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next(); - ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry._1(); - ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry._2(); + $key = (${ctx.boxedType(keyType)}) $entry._1(); + $value = (${ctx.boxedType(valueType)}) $entry._2(); """ defineEntries -> defineKeyValue } + val keyNullCheck = if (ctx.isPrimitiveType(keyType)) { + s"$keyIsNull = false;" + } else { + s"$keyIsNull = $key == null;" + } + val valueNullCheck = if (ctx.isPrimitiveType(valueType)) { - s"boolean $valueIsNull = false;" + s"$valueIsNull = false;" } else { - s"boolean $valueIsNull = $value == null;" + s"$valueIsNull = $value == null;" } val arrayCls = classOf[GenericArrayData].getName @@ -781,6 +803,7 @@ case class ExternalMapToCatalyst private( $defineEntries while($entries.hasNext()) { $defineKeyValue + $keyNullCheck $valueNullCheck ${genKeyConverter.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5034566132f7a..358d0e7099484 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,23 +17,26 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.immutable.TreeSet + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ object InterpretedPredicate { - def create(expression: Expression, inputSchema: Seq[Attribute]): (InternalRow => Boolean) = + def create(expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = create(BindReferences.bindReference(expression, inputSchema)) - def create(expression: Expression): (InternalRow => Boolean) = { - (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean] - } + def create(expression: Expression): InterpretedPredicate = new InterpretedPredicate(expression) } +case class InterpretedPredicate(expression: Expression) extends BasePredicate { + override def eval(r: InternalRow): Boolean = expression.eval(r).asInstanceOf[Boolean] +} /** * An [[Expression]] that returns a boolean value. @@ -162,19 +165,22 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. """.stripMargin) } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") } case _ => - if (list.exists(l => l.dataType != value.dataType)) { - TypeCheckResult.TypeCheckFailure("Arguments must be same type") + val mismatchOpt = list.find(l => l.dataType != value.dataType) + if (mismatchOpt.isDefined) { + TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " + + s"${value.dataType} != ${mismatchOpt.get.dataType}") } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName") } } } override def children: Seq[Expression] = value +: list lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal]) + private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType) override def nullable: Boolean = children.exists(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -189,10 +195,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { var hasNull = false list.foreach { e => val v = e.eval(input) - if (v == evaluatedValue) { - return true - } else if (v == null) { + if (v == null) { hasNull = true + } else if (ordering.equiv(v, evaluatedValue)) { + return true } } if (hasNull) { @@ -251,7 +257,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with override def nullable: Boolean = child.nullable || hasNull protected override def nullSafeEval(value: Any): Any = { - if (hset.contains(value)) { + if (set.contains(value)) { true } else if (hasNull) { null @@ -260,27 +266,40 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } } - def getHSet(): Set[Any] = hset + @transient private[this] lazy val set = child.dataType match { + case _: AtomicType => hset + case _: NullType => hset + case _ => + // for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows + TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset + } + + def getSet(): Set[Any] = set override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val setName = classOf[Set[Any]].getName val InSetName = classOf[InSet].getName val childGen = child.genCode(ctx) ctx.references += this - val hsetTerm = ctx.freshName("hset") - val hasNullTerm = ctx.freshName("hasNull") - ctx.addMutableState(setName, hsetTerm, - s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();") - ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);") + val setTerm = ctx.freshName("set") + val setNull = if (hasNull) { + s""" + |if (!${ev.value}) { + | ${ev.isNull} = true; + |} + """.stripMargin + } else { + "" + } + ctx.addMutableState(setName, setTerm, + s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();") ev.copy(code = s""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; boolean ${ev.value} = false; if (!${ev.isNull}) { - ${ev.value} = $hsetTerm.contains(${childGen.value}); - if (!${ev.value} && $hasNullTerm) { - ${ev.isNull} = true; - } + ${ev.value} = $setTerm.contains(${childGen.value}); + $setNull } """) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 3fa84589e3c68..aa5a1b5448c6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -86,6 +86,13 @@ abstract class StringRegexExpression extends BinaryExpression escape character, the following character is matched literally. It is invalid to escape any other character. + Since Spark 2.0, string literals are unescaped in our SQL parser. For example, in order + to match "\abc", the pattern should be "\\abc". + + When SQL config 'spark.sql.parser.escapedStringLiterals' is enabled, it fallbacks + to Spark 1.6 behavior regarding string literal parsing. For example, if the config is + enabled, the pattern to match "\abc" should be "\abc". + Examples: > SELECT '%SystemDrive%\Users\John' _FUNC_ '\%SystemDrive\%\\Users%' true @@ -144,7 +151,31 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } @ExpressionDescription( - usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.") + usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.", + extended = """ + Arguments: + str - a string expression + regexp - a string expression. The pattern string should be a Java regular expression. + + Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL parser. + For example, to match "\abc", a regular expression for `regexp` can be "^\\abc$". + + There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to fallback + to the Spark 1.6 behavior regarding string literal parsing. For example, if the config is + enabled, the `regexp` that can match "\abc" is "^\abc$". + + Examples: + When spark.sql.parser.escapedStringLiterals is disabled (default). + > SELECT '%SystemDrive%\Users\John' _FUNC_ '%SystemDrive%\\Users.*' + true + + When spark.sql.parser.escapedStringLiterals is enabled. + > SELECT '%SystemDrive%\Users\John' _FUNC_ '%SystemDrive%\Users.*' + true + + See also: + Use LIKE to match with simple string pattern. +""") case class RLike(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = v diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 5598a146997ca..a5e253b61e7f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1005,7 +1005,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC """, extended = """ Examples: - > SELECT initcap('sPark sql'); + > SELECT _FUNC_('sPark sql'); Spark Sql """) case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index e0ed03a68981a..025a388aacaa5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.json -import java.io.InputStream +import java.io.{ByteArrayInputStream, InputStream, InputStreamReader} import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.hadoop.io.Text @@ -33,7 +33,10 @@ private[sql] object CreateJacksonParser extends Serializable { val bb = record.getByteBuffer assert(bb.hasArray) - jsonFactory.createParser(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + val bain = new ByteArrayInputStream( + bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + + jsonFactory.createParser(new InputStreamReader(bain, "UTF-8")) } def text(jsonFactory: JsonFactory, record: Text): JsonParser = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 23ba5ed4d50dc..1fd680ab64b5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -70,7 +70,7 @@ private[sql] class JSONOptions( val columnNameOfCorruptRecord = parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) - val timeZone: TimeZone = TimeZone.getTimeZone( + val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. @@ -81,7 +81,7 @@ private[sql] class JSONOptions( FastDateFormat.getInstance( parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) - val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) + val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) /** Sets config options on a Jackson [[JsonFactory]]. */ def setJacksonOptions(factory: JsonFactory): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index ff6c93ae9815c..4ed6728994193 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.json import java.io.ByteArrayOutputStream -import java.util.Locale import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -126,16 +125,11 @@ class JacksonParser( case VALUE_STRING => // Special case handling for NaN and Infinity. - val value = parser.getText - val lowerCaseValue = value.toLowerCase(Locale.ROOT) - if (lowerCaseValue.equals("nan") || - lowerCaseValue.equals("infinity") || - lowerCaseValue.equals("-infinity") || - lowerCaseValue.equals("inf") || - lowerCaseValue.equals("-inf")) { - value.toFloat - } else { - throw new RuntimeException(s"Cannot parse $value as FloatType.") + parser.getText match { + case "NaN" => Float.NaN + case "Infinity" => Float.PositiveInfinity + case "-Infinity" => Float.NegativeInfinity + case other => throw new RuntimeException(s"Cannot parse $other as FloatType.") } } @@ -146,16 +140,11 @@ class JacksonParser( case VALUE_STRING => // Special case handling for NaN and Infinity. - val value = parser.getText - val lowerCaseValue = value.toLowerCase(Locale.ROOT) - if (lowerCaseValue.equals("nan") || - lowerCaseValue.equals("infinity") || - lowerCaseValue.equals("-infinity") || - lowerCaseValue.equals("inf") || - lowerCaseValue.equals("-inf")) { - value.toDouble - } else { - throw new RuntimeException(s"Cannot parse $value as DoubleType.") + parser.getText match { + case "NaN" => Double.NaN + case "Infinity" => Double.PositiveInfinity + case "-Infinity" => Double.NegativeInfinity + case other => throw new RuntimeException(s"Cannot parse $other as DoubleType.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala index 3b23c6cd2816f..134d16e981a15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala @@ -44,7 +44,9 @@ object JacksonUtils { case at: ArrayType => verifyType(name, at.elementType) - case mt: MapType => verifyType(name, mt.keyType) + // For MapType, its keys are treated as a string (i.e. calling `toString`) basically when + // generating JSON, so we only care if the values are valid for JSON. + case mt: MapType => verifyType(name, mt.valueType) case udt: UserDefinedType[_] => verifyType(name, udt.sqlType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d221b0611a892..71e03ee829710 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -113,17 +113,18 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) SimplifyCreateArrayOps, SimplifyCreateMapOps) ++ extendedOperatorOptimizationRules: _*) :: - Batch("Check Cartesian Products", Once, - CheckCartesianProducts(conf)) :: Batch("Join Reorder", Once, CostBasedJoinReorder(conf)) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates(conf)) :: - Batch("Typed Filter Optimization", fixedPoint, + Batch("Object Expressions Optimization", fixedPoint, + EliminateMapObjects, CombineTypedFilters) :: Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, PropagateEmptyRelation) :: + Batch("Check Cartesian Products", Once, + CheckCartesianProducts(conf)) :: Batch("OptimizeCodegen", Once, OptimizeCodegen(conf)) :: Batch("RewriteSubquery", Once, @@ -440,8 +441,7 @@ object ColumnPruning extends Rule[LogicalPlan] { g.copy(child = prunedChild(g.child, g.references)) // Turn off `join` for Generate if no column from it's child is used - case p @ Project(_, g: Generate) - if g.join && !g.outer && p.references.subsetOf(g.generatedSet) => + case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => p.copy(child = g.copy(join = false)) // Eliminate unneeded attributes from right side of a Left Existence Join. @@ -861,7 +861,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // Note that some operators (e.g. project, aggregate, union) are being handled separately // (earlier in this rule). case _: AppendColumns => true - case _: BroadcastHint => true + case _: ResolvedHint => true case _: Distinct => true case _: Generate => true case _: Pivot => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index 7400a01918c52..987cd7434b459 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -30,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules._ * - Join with one or two empty children (including Intersect/Except). * 2. Unary-node Logical Plans * - Project/Filter/Sample/Join/Limit/Repartition with all empty children. - * - Aggregate with all empty children and without AggregateFunction expressions like COUNT. + * - Aggregate with all empty children and at least one grouping expression. * - Generate(Explode) with all empty children. Others like Hive UDTF may return results. */ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { @@ -39,10 +38,6 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { case _ => false } - private def containsAggregateExpression(e: Expression): Boolean = { - e.collectFirst { case _: AggregateFunction => () }.isDefined - } - private def empty(plan: LogicalPlan) = LocalRelation(plan.output, data = Seq.empty) def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { @@ -68,8 +63,13 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { case _: LocalLimit => empty(p) case _: Repartition => empty(p) case _: RepartitionByExpression => empty(p) - // AggregateExpressions like COUNT(*) return their results like 0. - case Aggregate(_, ae, _) if !ae.exists(containsAggregateExpression) => empty(p) + // An aggregate with non-empty group expression will return one output row per group when the + // input to the aggregate is not empty. If the input to the aggregate is empty then all groups + // will be empty and thus the output will be empty. + // + // If the grouping expressions are empty, however, then the aggregate will always produce a + // single output row and thus we cannot propagate the EmptyRelation. + case Aggregate(ge, _, _) if ge.nonEmpty => empty(p) // Generators like Hive-style UDTF may return their records within `close`. case Generate(_: Explode, _, _, _, _, _) => empty(p) case _ => p diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 8445ee06bd89b..f2334830f8d88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -153,6 +154,11 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case TrueLiteral Or _ => TrueLiteral case _ Or TrueLiteral => TrueLiteral + case a And b if Not(a).semanticEquals(b) => FalseLiteral + case a Or b if Not(a).semanticEquals(b) => TrueLiteral + case a And b if a.semanticEquals(Not(b)) => FalseLiteral + case a Or b if a.semanticEquals(Not(b)) => TrueLiteral + case a And b if a.semanticEquals(b) => a case a Or b if a.semanticEquals(b) => a @@ -320,22 +326,27 @@ object LikeSimplification extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case Like(input, Literal(pattern, StringType)) => - pattern.toString match { - case startsWith(prefix) if !prefix.endsWith("\\") => - StartsWith(input, Literal(prefix)) - case endsWith(postfix) => - EndsWith(input, Literal(postfix)) - // 'a%a' pattern is basically same with 'a%' && '%a'. - // However, the additional `Length` condition is required to prevent 'a' match 'a%a'. - case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") => - And(GreaterThanOrEqual(Length(input), Literal(prefix.size + postfix.size)), - And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix)))) - case contains(infix) if !infix.endsWith("\\") => - Contains(input, Literal(infix)) - case equalTo(str) => - EqualTo(input, Literal(str)) - case _ => - Like(input, Literal.create(pattern, StringType)) + if (pattern == null) { + // If pattern is null, return null value directly, since "col like null" == null. + Literal(null, BooleanType) + } else { + pattern.toString match { + case startsWith(prefix) if !prefix.endsWith("\\") => + StartsWith(input, Literal(prefix)) + case endsWith(postfix) => + EndsWith(input, Literal(postfix)) + // 'a%a' pattern is basically same with 'a%' && '%a'. + // However, the additional `Length` condition is required to prevent 'a' match 'a%a'. + case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") => + And(GreaterThanOrEqual(Length(input), Literal(prefix.length + postfix.length)), + And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix)))) + case contains(infix) if !infix.endsWith("\\") => + Contains(input, Literal(infix)) + case equalTo(str) => + EqualTo(input, Literal(str)) + case _ => + Like(input, Literal.create(pattern, StringType)) + } } } } @@ -368,6 +379,8 @@ case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] { case EqualNullSafe(Literal(null, _), r) => IsNull(r) case EqualNullSafe(l, Literal(null, _)) => IsNull(l) + case AssertNotNull(c, _) if !c.nullable => c + // For Coalesce, remove null literals. case e @ Coalesce(children) => val newChildren = children.filterNot(isNullLiteral) @@ -469,7 +482,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { case _: Distinct => true case _: AppendColumns => true case _: AppendColumnsWithObject => true - case _: BroadcastHint => true + case _: ResolvedHint => true case _: RepartitionByExpression => true case _: Repartition => true case _: Sort => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 89e1dc9e322e0..af0837e36e8ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import java.util.TimeZone - import scala.collection.mutable import org.apache.spark.sql.catalyst.catalog.SessionCatalog @@ -55,7 +53,7 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { case CurrentDate(Some(timeZoneId)) => currentDates.getOrElseUpdate(timeZoneId, { Literal.create( - DateTimeUtils.millisToDays(timestamp / 1000L, TimeZone.getTimeZone(timeZoneId)), + DateTimeUtils.millisToDays(timestamp / 1000L, DateTimeUtils.getTimeZone(timeZoneId)), DateType) }) case CurrentTimestamp() => currentTime diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index c3ab58744953d..2fe3039774423 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -134,8 +134,8 @@ case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with Pred val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet)) val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet)) - val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) - val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) + lazy val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) + lazy val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) join.joinType match { case RightOuter if leftHasNonNullPredicate => Inner diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 257dbfac8c3e8..8cdc6425bcad8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -96,3 +97,15 @@ object CombineTypedFilters extends Rule[LogicalPlan] { } } } + +/** + * Removes MapObjects when the following conditions are satisfied + * 1. Mapobject(... lambdavariable(..., false) ...), which means types for input and output + * are primitive types with non-nullable + * 2. no custom collection class specified representation of data item. + */ +object EliminateMapObjects extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case MapObjects(_, _, _, LambdaVariable(_, _, _, false), inputData, None) => inputData + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e1db1ef5b8695..d1c9332bee18b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler @@ -44,9 +45,11 @@ import org.apache.spark.util.random.RandomSampler * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or * TableIdentifier. */ -class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { +class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging { import ParserUtils._ + def this() = this(new SQLConf()) + protected def typedVisit[T](ctx: ParseTree): T = { ctx.accept(this).asInstanceOf[T] } @@ -215,7 +218,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ protected def visitNonOptionalPartitionSpec( ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { - visitPartitionSpec(ctx).mapValues(_.orNull).map(identity) + visitPartitionSpec(ctx).map { + case (key, None) => throw new ParseException(s"Found an empty partition key '$key'.", ctx) + case (key, Some(value)) => key -> value + } } /** @@ -400,7 +406,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val withWindow = withDistinct.optionalMap(windows)(withWindows) // Hint - withWindow.optionalMap(hint)(withHints) + hints.asScala.foldRight(withWindow)(withHints) } } @@ -526,13 +532,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Add a [[Hint]] to a logical plan. + * Add [[UnresolvedHint]]s to a logical plan. */ private def withHints( ctx: HintContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { - val stmt = ctx.hintStatement - Hint(stmt.hintName.getText, stmt.parameters.asScala.map(_.getText), query) + var plan = query + ctx.hintStatements.asScala.reverse.foreach { case stmt => + plan = UnresolvedHint(stmt.hintName.getText, stmt.parameters.asScala.map(expression), plan) + } + plan } /** @@ -1024,6 +1033,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType)) } + /** + * Create a [[CreateStruct]] expression. + */ + override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) { + CreateStruct(ctx.argument.asScala.map(expression)) + } + /** * Create a [[First]] expression. */ @@ -1047,7 +1063,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // Create the function call. val name = ctx.qualifiedName.getText val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) - val arguments = ctx.namedExpression().asScala.map(expression) match { + val arguments = ctx.argument.asScala.map(expression) match { case Seq(UnresolvedStar(None)) if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct => // Transform COUNT(*) into COUNT(1). @@ -1406,7 +1422,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Special characters can be escaped by using Hive/C-style escaping. */ private def createString(ctx: StringLiteralContext): String = { - ctx.STRING().asScala.map(string).mkString + if (conf.escapedStringLiterals) { + ctx.STRING().asScala.map(stringWithoutUnescape).mkString + } else { + ctx.STRING().asScala.map(string).mkString + } } /** @@ -1488,8 +1508,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case ("decimal", precision :: scale :: Nil) => DecimalType(precision.getText.toInt, scale.getText.toInt) case (dt, params) => - throw new ParseException( - s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx) + val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt + throw new ParseException(s"DataType $dtStr is not supported.", ctx) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 80ab75cc17fab..8e2e973485e1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} /** @@ -34,8 +35,7 @@ import org.apache.spark.sql.types.{DataType, StructType} abstract class AbstractSqlParser extends ParserInterface with Logging { /** Creates/Resolves DataType for a given SQL string. */ - def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => - // TODO add this to the parser interface. + override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => astBuilder.visitSingleDataType(parser.singleDataType()) } @@ -50,8 +50,10 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { } /** Creates FunctionIdentifier for a given SQL string. */ - def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = parse(sqlText) { parser => - astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier()) + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = { + parse(sqlText) { parser => + astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier()) + } } /** @@ -120,8 +122,13 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { /** * Concrete SQL parser for Catalyst-only SQL statements. */ +class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser { + val astBuilder = new AstBuilder(conf) +} + +/** For test-only. */ object CatalystSqlParser extends AbstractSqlParser { - val astBuilder = new AstBuilder + val astBuilder = new AstBuilder(new SQLConf()) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala index db3598bde04d3..75240d2196222 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala @@ -17,30 +17,51 @@ package org.apache.spark.sql.catalyst.parser +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} /** * Interface for a parser. */ +@DeveloperApi trait ParserInterface { - /** Creates LogicalPlan for a given SQL string. */ + /** + * Parse a string to a [[LogicalPlan]]. + */ + @throws[ParseException]("Text cannot be parsed to a LogicalPlan") def parsePlan(sqlText: String): LogicalPlan - /** Creates Expression for a given SQL string. */ + /** + * Parse a string to an [[Expression]]. + */ + @throws[ParseException]("Text cannot be parsed to an Expression") def parseExpression(sqlText: String): Expression - /** Creates TableIdentifier for a given SQL string. */ + /** + * Parse a string to a [[TableIdentifier]]. + */ + @throws[ParseException]("Text cannot be parsed to a TableIdentifier") def parseTableIdentifier(sqlText: String): TableIdentifier - /** Creates FunctionIdentifier for a given SQL string. */ + /** + * Parse a string to a [[FunctionIdentifier]]. + */ + @throws[ParseException]("Text cannot be parsed to a FunctionIdentifier") def parseFunctionIdentifier(sqlText: String): FunctionIdentifier /** - * Creates StructType for a given SQL string, which is a comma separated list of field - * definitions which will preserve the correct Hive metadata. + * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list + * of field definitions which will preserve the correct Hive metadata. */ + @throws[ParseException]("Text cannot be parsed to a schema") def parseTableSchema(sqlText: String): StructType + + /** + * Parse a string to a [[DataType]]. + */ + @throws[ParseException]("Text cannot be parsed to a DataType") + def parseDataType(sqlText: String): DataType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 6fbc33fad735c..77fdaa8255aa6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -68,6 +68,12 @@ object ParserUtils { /** Convert a string node into a string. */ def string(node: TerminalNode): String = unescapeSQLString(node.getText) + /** Convert a string node into a string without unescaping. */ + def stringWithoutUnescape(node: TerminalNode): String = { + // STRING parser rule forces that the input always has quotes at the starting and ending. + node.getText.slice(1, node.getText.size - 1) + } + /** Get the origin (line and position) of the token. */ def position(token: Token): Origin = { val opt = Option(token) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index d39b0ef7e1d8a..ef925f92ecc7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -65,8 +65,8 @@ object PhysicalOperation extends PredicateHelper { val substitutedCondition = substitute(aliases)(condition) (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases) - case BroadcastHint(child) => - collectProjectsAndFilters(child) + case h: ResolvedHint => + collectProjectsAndFilters(h.child) case other => (None, Nil, other, Map.empty) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 2fb65bd435507..d3f822bf7eb0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -81,11 +81,12 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT case _ => Seq.empty[Attribute] } - // Collect aliases from expressions, so we may avoid producing recursive constraints. - private lazy val aliasMap = AttributeMap( - (expressions ++ children.flatMap(_.expressions)).collect { + // Collect aliases from expressions of the whole tree rooted by the current QueryPlan node, so + // we may avoid producing recursive constraints. + private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( + expressions.collect { case a: Alias => (a.toAttribute, a.child) - }) + } ++ children.flatMap(_.aliasMap)) /** * Infers an additional set of constraints from a given set of equality constraints. @@ -286,7 +287,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT def recursiveTransform(arg: Any): AnyRef = arg match { case e: Expression => transformExpression(e) - case Some(e: Expression) => Some(transformExpression(e)) + case Some(value) => Some(recursiveTransform(value)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map(recursiveTransform) @@ -320,7 +321,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT productIterator.flatMap { case e: Expression => e :: Nil - case Some(e: Expression) => e :: Nil + case s: Some[_] => seqToExpressions(s.toSeq) case seq: Traversable[_] => seqToExpressions(seq) case other => Nil }.toSeq @@ -423,7 +424,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT lazy val allAttributes: AttributeSeq = children.flatMap(_.output) } -object QueryPlan { +object QueryPlan extends PredicateHelper { /** * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference` * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we @@ -442,4 +443,17 @@ object QueryPlan { } }.canonicalized.asInstanceOf[T] } + + /** + * Composes the given predicates into a conjunctive predicate, which is normalized and reordered. + * Then returns a new sequence of predicates by splitting the conjunctive predicate. + */ + def normalizePredicates(predicates: Seq[Expression], output: AttributeSeq): Seq[Expression] = { + if (predicates.nonEmpty) { + val normalized = normalizeExprId(predicates.reduce(And), output) + splitConjunctivePredicates(normalized) + } else { + Nil + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 6bdcf490ca5c8..2ebb2ff323c6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -347,7 +347,7 @@ abstract class UnaryNode extends LogicalPlan { } // Don't propagate rowCount and attributeStats, since they are not estimated here. - Statistics(sizeInBytes = sizeInBytes, isBroadcastable = child.stats(conf).isBroadcastable) + Statistics(sizeInBytes = sizeInBytes, hints = child.stats(conf).hints) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 3d4efef953a64..a64562b5dbd93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -46,13 +46,13 @@ import org.apache.spark.util.Utils * defaults to the product of children's `sizeInBytes`. * @param rowCount Estimated number of rows. * @param attributeStats Statistics for Attributes. - * @param isBroadcastable If true, output is small enough to be used in a broadcast join. + * @param hints Query hints. */ case class Statistics( sizeInBytes: BigInt, rowCount: Option[BigInt] = None, attributeStats: AttributeMap[ColumnStat] = AttributeMap(Nil), - isBroadcastable: Boolean = false) { + hints: HintInfo = HintInfo()) { override def toString: String = "Statistics(" + simpleString + ")" @@ -65,7 +65,7 @@ case class Statistics( } else { "" }, - s"isBroadcastable=$isBroadcastable" + s"hints=$hints" ).filter(_.nonEmpty).mkString(", ") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3ad757ebba851..2eee94364d84e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -83,7 +83,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the - * given `generator` is empty. `outer` has no effect when `join` is false. + * given `generator` is empty. * @param qualifier Qualifier for the attributes of generator(UDTF) * @param generatorOutput The output schema of the Generator. * @param child Children logical plan node @@ -195,9 +195,9 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation val leftSize = left.stats(conf).sizeInBytes val rightSize = right.stats(conf).sizeInBytes val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize - val isBroadcastable = left.stats(conf).isBroadcastable || right.stats(conf).isBroadcastable - - Statistics(sizeInBytes = sizeInBytes, isBroadcastable = isBroadcastable) + Statistics( + sizeInBytes = sizeInBytes, + hints = left.stats(conf).hints.resetForJoin()) } } @@ -364,7 +364,8 @@ case class Join( case _ => // Make sure we don't propagate isBroadcastable in other joins, because // they could explode the size. - super.computeStats(conf).copy(isBroadcastable = false) + val stats = super.computeStats(conf) + stats.copy(hints = stats.hints.resetForJoin()) } if (conf.cboEnabled) { @@ -375,26 +376,6 @@ case class Join( } } -/** - * A hint for the optimizer that we should broadcast the `child` if used in a join operator. - */ -case class BroadcastHint(child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output - - // set isBroadcastable to true so the child will be broadcasted - override def computeStats(conf: SQLConf): Statistics = - child.stats(conf).copy(isBroadcastable = true) -} - -/** - * A general hint for the child. This node will be eliminated post analysis. - * A pair of (name, parameters). - */ -case class Hint(name: String, parameters: Seq[String], child: LogicalPlan) extends UnaryNode { - override lazy val resolved: Boolean = false - override def output: Seq[Attribute] = child.output -} - /** * Insert some data into a table. Note that this plan is unresolved and has to be replaced by the * concrete implementations during analysis. @@ -410,17 +391,20 @@ case class Hint(name: String, parameters: Seq[String], child: LogicalPlan) exten * would have Map('a' -> Some('1'), 'b' -> None). * @param query the logical plan representing data to write to. * @param overwrite overwrite existing table or partitions. - * @param ifNotExists If true, only write if the table or partition does not exist. + * @param ifPartitionNotExists If true, only write if the partition does not exist. + * Only valid for static partitions. */ case class InsertIntoTable( table: LogicalPlan, partition: Map[String, Option[String]], query: LogicalPlan, overwrite: Boolean, - ifNotExists: Boolean) + ifPartitionNotExists: Boolean) extends LogicalPlan { - assert(overwrite || !ifNotExists) - assert(partition.values.forall(_.nonEmpty) || !ifNotExists) + // IF NOT EXISTS is only valid in INSERT OVERWRITE + assert(overwrite || !ifPartitionNotExists) + // IF NOT EXISTS is only valid in static partitions + assert(partition.values.forall(_.nonEmpty) || !ifPartitionNotExists) // We don't want `table` in children as sometimes we don't want to transform it. override def children: Seq[LogicalPlan] = query :: Nil @@ -577,7 +561,7 @@ case class Aggregate( Statistics( sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1), rowCount = Some(1), - isBroadcastable = child.stats(conf).isBroadcastable) + hints = child.stats(conf).hints) } else { super.computeStats(conf) } @@ -704,7 +688,7 @@ case class Expand( * We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer * * @param selectedGroupByExprs A sequence of selected GroupBy expressions, all exprs should - * exists in groupByExprs. + * exist in groupByExprs. * @param groupByExprs The Group By expressions candidates. * @param child Child operator * @param aggregations The Aggregation expressions, those non selected group by expressions @@ -766,7 +750,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN Statistics( sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, childStats.attributeStats), rowCount = Some(rowCount), - isBroadcastable = childStats.isBroadcastable) + hints = childStats.hints) } } @@ -787,7 +771,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo Statistics( sizeInBytes = 1, rowCount = Some(0), - isBroadcastable = childStats.isBroadcastable) + hints = childStats.hints) } else { // The output row count of LocalLimit should be the sum of row counts from each partition. // However, since the number of partitions is not available here, we just use statistics of @@ -838,7 +822,7 @@ case class Sample( } val sampledRowCount = childStats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio)) // Don't propagate column stats, because we don't know the distribution after a sample operation - Statistics(sizeInBytes, sampledRowCount, isBroadcastable = childStats.isBroadcastable) + Statistics(sizeInBytes, sampledRowCount, hints = childStats.hints) } override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala new file mode 100644 index 0000000000000..d16fae56b3d4a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf + +/** + * A general hint for the child that is not yet resolved. This node is generated by the parser and + * should be removed This node will be eliminated post analysis. + * @param name the name of the hint + * @param parameters the parameters of the hint + * @param child the [[LogicalPlan]] on which this hint applies + */ +case class UnresolvedHint(name: String, parameters: Seq[Any], child: LogicalPlan) + extends UnaryNode { + + override lazy val resolved: Boolean = false + override def output: Seq[Attribute] = child.output +} + +/** + * A resolved hint node. The analyzer should convert all [[UnresolvedHint]] into [[ResolvedHint]]. + */ +case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo()) + extends UnaryNode { + + override def output: Seq[Attribute] = child.output + + override lazy val canonicalized: LogicalPlan = child.canonicalized + + override def computeStats(conf: SQLConf): Statistics = { + val stats = child.stats(conf) + stats.copy(hints = hints) + } +} + + +case class HintInfo( + isBroadcastable: Option[Boolean] = None) { + + /** Must be called when computing stats for a join operator to reset hints. */ + def resetForJoin(): HintInfo = copy( + isBroadcastable = None + ) + + override def toString: String = { + if (productIterator.forall(_.asInstanceOf[Option[_]].isEmpty)) { + "none" + } else { + isBroadcastable.map(x => s"isBroadcastable=$x").getOrElse("") + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index 48b5fbb03ef1e..a0c23198451a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -56,7 +56,7 @@ object AggregateEstimation { sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats), rowCount = Some(outputRows), attributeStats = outputAttrStats, - isBroadcastable = childStats.isBroadcastable)) + hints = childStats.hints)) } else { None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index f1aff62cb6af0..e5fcdf9039be9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -43,6 +43,18 @@ object EstimationUtils { avgLen = dataType.defaultSize, maxLen = dataType.defaultSize) } + /** + * Updates (scales down) the number of distinct values if the number of rows decreases after + * some operation (such as filter, join). Otherwise keep it unchanged. + */ + def updateNdv(oldNumRows: BigInt, newNumRows: BigInt, oldNdv: BigInt): BigInt = { + if (newNumRows < oldNumRows) { + ceil(BigDecimal(oldNdv) * BigDecimal(newNumRows) / BigDecimal(oldNumRows)) + } else { + oldNdv + } + } + def ceil(bigDecimal: BigDecimal): BigInt = bigDecimal.setScale(0, RoundingMode.CEILING).toBigInt() /** Get column stats for output attributes. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 4b6b3b14d9ac8..df190867189ec 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import scala.collection.immutable.HashSet import scala.collection.mutable -import scala.math.BigDecimal.RoundingMode import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -32,14 +32,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging private val childStats = plan.child.stats(catalystConf) - /** - * We will update the corresponding ColumnStats for a column after we apply a predicate condition. - * For example, column c has [min, max] value as [0, 100]. In a range condition such as - * (c > 40 AND c <= 50), we need to set the column's [min, max] value to [40, 100] after we - * evaluate the first condition c > 40. We need to set the column's [min, max] value to [40, 50] - * after we evaluate the second condition c <= 50. - */ - private val colStatsMap = new ColumnStatsMap + private val colStatsMap = new ColumnStatsMap(childStats.attributeStats) /** * Returns an option of Statistics for a Filter logical plan node. @@ -53,24 +46,19 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging def estimate: Option[Statistics] = { if (childStats.rowCount.isEmpty) return None - // Save a mutable copy of colStats so that we can later change it recursively. - colStatsMap.setInitValues(childStats.attributeStats) - // Estimate selectivity of this filter predicate, and update column stats if needed. // For not-supported condition, set filter selectivity to a conservative estimate 100% - val filterSelectivity: Double = calculateFilterSelectivity(plan.condition).getOrElse(1.0) + val filterSelectivity = calculateFilterSelectivity(plan.condition).getOrElse(BigDecimal(1.0)) - val newColStats = if (filterSelectivity == 0) { + val filteredRowCount: BigInt = ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity) + val newColStats = if (filteredRowCount == 0) { // The output is empty, we don't need to keep column stats. AttributeMap[ColumnStat](Nil) } else { - colStatsMap.toColumnStats + colStatsMap.outputColumnStats(rowsBeforeFilter = childStats.rowCount.get, + rowsAfterFilter = filteredRowCount) } - - val filteredRowCount: BigInt = - EstimationUtils.ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity) - val filteredSizeInBytes: BigInt = - EstimationUtils.getOutputSize(plan.output, filteredRowCount, newColStats) + val filteredSizeInBytes: BigInt = getOutputSize(plan.output, filteredRowCount, newColStats) Some(childStats.copy(sizeInBytes = filteredSizeInBytes, rowCount = Some(filteredRowCount), attributeStats = newColStats)) @@ -92,16 +80,17 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging * @return an optional double value to show the percentage of rows meeting a given condition. * It returns None if the condition is not supported. */ - def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { + def calculateFilterSelectivity(condition: Expression, update: Boolean = true) + : Option[BigDecimal] = { condition match { case And(cond1, cond2) => - val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(1.0) - val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(1.0) + val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(BigDecimal(1.0)) + val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(BigDecimal(1.0)) Some(percent1 * percent2) case Or(cond1, cond2) => - val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(1.0) - val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0) + val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(BigDecimal(1.0)) + val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(BigDecimal(1.0)) Some(percent1 + percent2 - (percent1 * percent2)) // Not-operator pushdown @@ -143,7 +132,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging * @return an optional double value to show the percentage of rows meeting a given condition. * It returns None if the condition is not supported. */ - def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = { + def calculateSingleCondition(condition: Expression, update: Boolean): Option[BigDecimal] = { condition match { case l: Literal => evaluateLiteral(l) @@ -237,7 +226,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging def evaluateNullCheck( attr: Attribute, isNull: Boolean, - update: Boolean): Option[Double] = { + update: Boolean): Option[BigDecimal] = { if (!colStatsMap.contains(attr)) { logDebug("[CBO] No statistics for " + attr) return None @@ -256,7 +245,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging } else { colStat.copy(nullCount = 0) } - colStatsMap(attr) = newStats + colStatsMap.update(attr, newStats) } val percent = if (isNull) { @@ -265,7 +254,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging 1.0 - nullPercent } - Some(percent.toDouble) + Some(percent) } /** @@ -283,7 +272,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging op: BinaryComparison, attr: Attribute, literal: Literal, - update: Boolean): Option[Double] = { + update: Boolean): Option[BigDecimal] = { if (!colStatsMap.contains(attr)) { logDebug("[CBO] No statistics for " + attr) return None @@ -317,7 +306,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging def evaluateEquality( attr: Attribute, literal: Literal, - update: Boolean): Option[Double] = { + update: Boolean): Option[BigDecimal] = { if (!colStatsMap.contains(attr)) { logDebug("[CBO] No statistics for " + attr) return None @@ -341,10 +330,10 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging colStat.copy(distinctCount = 1, min = Some(literal.value), max = Some(literal.value), nullCount = 0) } - colStatsMap(attr) = newStats + colStatsMap.update(attr, newStats) } - Some((1.0 / BigDecimal(ndv)).toDouble) + Some(1.0 / BigDecimal(ndv)) } else { Some(0.0) } @@ -361,7 +350,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging * @param literal a literal value (or constant) * @return an optional double value to show the percentage of rows meeting a given condition */ - def evaluateLiteral(literal: Literal): Option[Double] = { + def evaluateLiteral(literal: Literal): Option[BigDecimal] = { literal match { case Literal(null, _) => Some(0.0) case FalseLiteral => Some(0.0) @@ -386,7 +375,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging def evaluateInSet( attr: Attribute, hSet: Set[Any], - update: Boolean): Option[Double] = { + update: Boolean): Option[BigDecimal] = { if (!colStatsMap.contains(attr)) { logDebug("[CBO] No statistics for " + attr) return None @@ -417,7 +406,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging if (update) { val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin), max = Some(newMax), nullCount = 0) - colStatsMap(attr) = newStats + colStatsMap.update(attr, newStats) } // We assume the whole set since there is no min/max information for String/Binary type @@ -425,13 +414,13 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging newNdv = ndv.min(BigInt(hSet.size)) if (update) { val newStats = colStat.copy(distinctCount = newNdv, nullCount = 0) - colStatsMap(attr) = newStats + colStatsMap.update(attr, newStats) } } // return the filter selectivity. Without advanced statistics such as histograms, // we have to assume uniform distribution. - Some(math.min(1.0, (BigDecimal(newNdv) / BigDecimal(ndv)).toDouble)) + Some((BigDecimal(newNdv) / BigDecimal(ndv)).min(1.0)) } /** @@ -449,7 +438,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging op: BinaryComparison, attr: Attribute, literal: Literal, - update: Boolean): Option[Double] = { + update: Boolean): Option[BigDecimal] = { val colStat = colStatsMap(attr) val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] @@ -518,7 +507,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging val newValue = Some(literal.value) var newMax = colStat.max var newMin = colStat.min - var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() + var newNdv = ceil(ndv * percent) if (newNdv < 1) newNdv = 1 op match { @@ -532,11 +521,11 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging val newStats = colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0) - colStatsMap(attr) = newStats + colStatsMap.update(attr, newStats) } } - Some(percent.toDouble) + Some(percent) } /** @@ -557,7 +546,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging op: BinaryComparison, attrLeft: Attribute, attrRight: Attribute, - update: Boolean): Option[Double] = { + update: Boolean): Option[BigDecimal] = { if (!colStatsMap.contains(attrLeft)) { logDebug("[CBO] No statistics for " + attrLeft) @@ -654,10 +643,10 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging // Need to adjust new min/max after the filter condition is applied val ndvLeft = BigDecimal(colStatLeft.distinctCount) - var newNdvLeft = (ndvLeft * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() + var newNdvLeft = ceil(ndvLeft * percent) if (newNdvLeft < 1) newNdvLeft = 1 val ndvRight = BigDecimal(colStatRight.distinctCount) - var newNdvRight = (ndvRight * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() + var newNdvRight = ceil(ndvRight * percent) if (newNdvRight < 1) newNdvRight = 1 var newMaxLeft = colStatLeft.max @@ -750,24 +739,57 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging } } - Some(percent.toDouble) + Some(percent) } } -class ColumnStatsMap { - private val baseMap: mutable.Map[ExprId, (Attribute, ColumnStat)] = mutable.HashMap.empty +/** + * This class contains the original column stats from child, and maintains the updated column stats. + * We will update the corresponding ColumnStats for a column after we apply a predicate condition. + * For example, column c has [min, max] value as [0, 100]. In a range condition such as + * (c > 40 AND c <= 50), we need to set the column's [min, max] value to [40, 100] after we + * evaluate the first condition c > 40. We also need to set the column's [min, max] value to + * [40, 50] after we evaluate the second condition c <= 50. + * + * @param originalMap Original column stats from child. + */ +case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) { - def setInitValues(colStats: AttributeMap[ColumnStat]): Unit = { - baseMap.clear() - baseMap ++= colStats.baseMap - } + /** This map maintains the latest column stats. */ + private val updatedMap: mutable.Map[ExprId, (Attribute, ColumnStat)] = mutable.HashMap.empty - def contains(a: Attribute): Boolean = baseMap.contains(a.exprId) + def contains(a: Attribute): Boolean = updatedMap.contains(a.exprId) || originalMap.contains(a) - def apply(a: Attribute): ColumnStat = baseMap(a.exprId)._2 + /** + * Gets column stat for the given attribute. Prefer the column stat in updatedMap than that in + * originalMap, because updatedMap has the latest (updated) column stats. + */ + def apply(a: Attribute): ColumnStat = { + if (updatedMap.contains(a.exprId)) { + updatedMap(a.exprId)._2 + } else { + originalMap(a) + } + } - def update(a: Attribute, stats: ColumnStat): Unit = baseMap.update(a.exprId, a -> stats) + /** Updates column stats in updatedMap. */ + def update(a: Attribute, stats: ColumnStat): Unit = updatedMap.update(a.exprId, a -> stats) - def toColumnStats: AttributeMap[ColumnStat] = AttributeMap(baseMap.values.toSeq) + /** + * Collects updated column stats, and scales down ndv for other column stats if the number of rows + * decreases after this Filter operator. + */ + def outputColumnStats(rowsBeforeFilter: BigInt, rowsAfterFilter: BigInt) + : AttributeMap[ColumnStat] = { + val newColumnStats = originalMap.map { case (attr, oriColStat) => + // Update ndv based on the overall filter selectivity: scale down ndv if the number of rows + // decreases; otherwise keep it unchanged. + val newNdv = EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter, + newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount) + val colStat = updatedMap.get(attr.exprId).map(_._2).getOrElse(oriColStat) + attr -> colStat.copy(distinctCount = newNdv) + } + AttributeMap(newColumnStats.toSeq) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 3245a73c8a2eb..8ef905c45d50d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -217,32 +217,17 @@ case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging { if (joinKeyStats.contains(a)) { outputAttrStats += a -> joinKeyStats(a) } else { - val leftRatio = if (leftRows != 0) { - BigDecimal(outputRows) / BigDecimal(leftRows) - } else { - BigDecimal(0) - } - val rightRatio = if (rightRows != 0) { - BigDecimal(outputRows) / BigDecimal(rightRows) - } else { - BigDecimal(0) - } val oldColStat = oldAttrStats(a) val oldNdv = oldColStat.distinctCount - // We only change (scale down) the number of distinct values if the number of rows - // decreases after join, because join won't produce new values even if the number of - // rows increases. - val newNdv = if (join.left.outputSet.contains(a) && leftRatio < 1) { - ceil(BigDecimal(oldNdv) * leftRatio) - } else if (join.right.outputSet.contains(a) && rightRatio < 1) { - ceil(BigDecimal(oldNdv) * rightRatio) + val newNdv = if (join.left.outputSet.contains(a)) { + updateNdv(oldNumRows = leftRows, newNumRows = outputRows, oldNdv = oldNdv) } else { - oldNdv + updateNdv(oldNumRows = rightRows, newNumRows = outputRows, oldNdv = oldNdv) } + val newColStat = oldColStat.copy(distinctCount = newNdv) // TODO: support nullCount updates for specific outer joins - outputAttrStats += a -> oldColStat.copy(distinctCount = newNdv) + outputAttrStats += a -> newColStat } - } outputAttrStats } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 6fc828f63f152..85b368c862630 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -122,7 +122,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { logDebug( s""" |=== Result of Batch ${batch.name} === - |${sideBySide(plan.treeString, curPlan.treeString).mkString("\n")} + |${sideBySide(batchStartPlan.treeString, curPlan.treeString).mkString("\n")} """.stripMargin) } else { logTrace(s"Batch ${batch.name} has no effect.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index cc4c0835954ba..ae5c513eb040b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -340,8 +340,18 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { arg } case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => - val newChild1 = f(arg1.asInstanceOf[BaseType]) - val newChild2 = f(arg2.asInstanceOf[BaseType]) + val newChild1 = if (containsChild(arg1)) { + f(arg1.asInstanceOf[BaseType]) + } else { + arg1.asInstanceOf[BaseType] + } + + val newChild2 = if (containsChild(arg2)) { + f(arg2.asInstanceOf[BaseType]) + } else { + arg2.asInstanceOf[BaseType] + } + if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { changed = true (newChild1, newChild2) @@ -444,6 +454,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case None => Nil case Some(null) => Nil case Some(any) => any :: Nil + case table: CatalogTable => + table.storage.serde match { + case Some(serde) => table.identifier :: serde :: Nil + case _ => table.identifier :: Nil + } case other => other :: Nil }.mkString(", ") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index eb6aad5b2d2bb..02cfa6e1b8afd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import java.util.{Calendar, Locale, TimeZone} +import java.util.concurrent.ConcurrentHashMap +import java.util.function.{Function => JFunction} import javax.xml.bind.DatatypeConverter import scala.annotation.tailrec @@ -30,7 +32,7 @@ import org.apache.spark.unsafe.types.UTF8String * Helper functions for converting between internal and external date and time representations. * Dates are exposed externally as java.sql.Date and are represented internally as the number of * dates since the Unix epoch (1970-01-01). Timestamps are exposed externally as java.sql.Timestamp - * and are stored internally as longs, which are capable of storing timestamps with 100 nanosecond + * and are stored internally as longs, which are capable of storing timestamps with microsecond * precision. */ object DateTimeUtils { @@ -98,6 +100,15 @@ object DateTimeUtils { sdf } + private val computedTimeZones = new ConcurrentHashMap[String, TimeZone] + private val computeTimeZone = new JFunction[String, TimeZone] { + override def apply(timeZoneId: String): TimeZone = TimeZone.getTimeZone(timeZoneId) + } + + def getTimeZone(timeZoneId: String): TimeZone = { + computedTimeZones.computeIfAbsent(timeZoneId, computeTimeZone) + } + def newDateFormat(formatString: String, timeZone: TimeZone): DateFormat = { val sdf = new SimpleDateFormat(formatString, Locale.US) sdf.setTimeZone(timeZone) @@ -388,13 +399,14 @@ object DateTimeUtils { digitsMilli += 1 } - if (!justTime && isInvalidDate(segments(0), segments(1), segments(2))) { - return None + // We are truncating the nanosecond part, which results in loss of precision + while (digitsMilli > 6) { + segments(6) /= 10 + digitsMilli -= 1 } - // Instead of return None, we truncate the fractional seconds to prevent inserting NULL - if (segments(6) > 999999) { - segments(6) = segments(6).toString.take(6).toInt + if (!justTime && isInvalidDate(segments(0), segments(1), segments(2))) { + return None } if (segments(3) < 0 || segments(3) > 23 || segments(4) < 0 || segments(4) > 59 || @@ -407,7 +419,7 @@ object DateTimeUtils { Calendar.getInstance(timeZone) } else { Calendar.getInstance( - TimeZone.getTimeZone(f"GMT${tz.get.toChar}${segments(7)}%02d:${segments(8)}%02d")) + getTimeZone(f"GMT${tz.get.toChar}${segments(7)}%02d:${segments(8)}%02d")) } c.set(Calendar.MILLISECOND, 0) @@ -592,7 +604,14 @@ object DateTimeUtils { */ private[this] def getYearAndDayInYear(daysSince1970: SQLDate): (Int, Int) = { // add the difference (in days) between 1.1.1970 and the artificial year 0 (-17999) - val daysNormalized = daysSince1970 + toYearZero + var daysSince1970Tmp = daysSince1970 + // Since Julian calendar was replaced with the Gregorian calendar, + // the 10 days after Oct. 4 were skipped. + // (1582-10-04) -141428 days since 1970-01-01 + if (daysSince1970 <= -141428) { + daysSince1970Tmp -= 10 + } + val daysNormalized = daysSince1970Tmp + toYearZero val numOfQuarterCenturies = daysNormalized / daysIn400Years val daysInThis400 = daysNormalized % daysIn400Years + 1 val (years, dayInYear) = numYears(daysInThis400) @@ -1027,7 +1046,7 @@ object DateTimeUtils { * representation in their timezone. */ def fromUTCTime(time: SQLTimestamp, timeZone: String): SQLTimestamp = { - convertTz(time, TimeZoneGMT, TimeZone.getTimeZone(timeZone)) + convertTz(time, TimeZoneGMT, getTimeZone(timeZone)) } /** @@ -1035,7 +1054,7 @@ object DateTimeUtils { * string representation in their timezone. */ def toUTCTime(time: SQLTimestamp, timeZone: String): SQLTimestamp = { - convertTz(time, TimeZone.getTimeZone(timeZone), TimeZoneGMT) + convertTz(time, getTimeZone(timeZone), TimeZoneGMT) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 7101ca5a17de9..45225779bffcb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -70,7 +70,9 @@ object TypeUtils { def compareBinary(x: Array[Byte], y: Array[Byte]): Int = { for (i <- 0 until x.length; if i < y.length) { - val res = x(i).compareTo(y(i)) + val v1 = x(i) & 0xff + val v2 = y(i) & 0xff + val res = v1 - v2 if (res != 0) return res } x.length - y.length diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2e1798e22b9fc..1623240de67cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -162,6 +162,12 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ADAPTIVE_EXECUTION_DISABLED_FOR_JOINING = + buildConf("spark.sql.adaptive.disabled.for.join") + .doc("When true, disable adaptive query execution when performing joining.") + .booleanConf + .createWithDefault(false) + val SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS = buildConf("spark.sql.adaptive.minNumPostShufflePartitions") .internal() @@ -196,6 +202,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals") + .internal() + .doc("When true, string literals (including regex patterns) remain escaped in our SQL " + + "parser. The default is false since Spark 2.0. Setting it to true can restore the behavior " + + "prior to Spark 2.0.") + .booleanConf + .createWithDefault(false) + val PARQUET_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.parquet.mergeSchema") .doc("When true, the Parquet data source merges schemas collected from all data files, " + "otherwise the schema is picked from the summary file or a random data file " + @@ -295,7 +309,7 @@ object SQLConf { val HIVE_MANAGE_FILESOURCE_PARTITIONS = buildConf("spark.sql.hive.manageFilesourcePartitions") .doc("When true, enable metastore partition management for file source tables as well. " + - "This includes both datasource and converted Hive tables. When partition managment " + + "This includes both datasource and converted Hive tables. When partition management " + "is enabled, datasource tables store partition in the Hive metastore, and use the " + "metastore to prune partitions during query planning.") .booleanConf @@ -337,7 +351,8 @@ object SQLConf { .createWithDefault(true) val COLUMN_NAME_OF_CORRUPT_RECORD = buildConf("spark.sql.columnNameOfCorruptRecord") - .doc("The name of internal column for storing raw/un-parsed JSON records that fail to parse.") + .doc("The name of internal column for storing raw/un-parsed JSON and CSV records that fail " + + "to parse.") .stringConf .createWithDefault("_corrupt_record") @@ -421,6 +436,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val GROUP_BY_ALIASES = buildConf("spark.sql.groupByAliases") + .doc("When true, aliases in a select list can be used in group by clauses. When false, " + + "an analysis exception is thrown in the case.") + .booleanConf + .createWithDefault(true) + // The output committer class used by data sources. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. val OUTPUT_COMMITTER_CLASS = @@ -521,8 +542,7 @@ object SQLConf { val IGNORE_CORRUPT_FILES = buildConf("spark.sql.files.ignoreCorruptFiles") .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " + - "encountering corrupted or non-existing and contents that have been read will still be " + - "returned.") + "encountering corrupted files and the contents that have been read will still be returned.") .booleanConf .createWithDefault(false) @@ -760,24 +780,47 @@ object SQLConf { .stringConf .createWithDefaultFunction(() => TimeZone.getDefault.getID) + val WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD = + buildConf("spark.sql.windowExec.buffer.in.memory.threshold") + .internal() + .doc("Threshold for number of rows guaranteed to be held in memory by the window operator") + .intConf + .createWithDefault(4096) + val WINDOW_EXEC_BUFFER_SPILL_THRESHOLD = buildConf("spark.sql.windowExec.buffer.spill.threshold") .internal() - .doc("Threshold for number of rows buffered in window operator") + .doc("Threshold for number of rows to be spilled by window operator") .intConf - .createWithDefault(4096) + .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + + val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD = + buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold") + .internal() + .doc("Threshold for number of rows guaranteed to be held in memory by the sort merge " + + "join operator") + .intConf + .createWithDefault(Int.MaxValue) val SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD = buildConf("spark.sql.sortMergeJoinExec.buffer.spill.threshold") .internal() - .doc("Threshold for number of rows buffered in sort merge join operator") + .doc("Threshold for number of rows to be spilled by sort merge join operator") .intConf - .createWithDefault(Int.MaxValue) + .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) + + val CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD = + buildConf("spark.sql.cartesianProductExec.buffer.in.memory.threshold") + .internal() + .doc("Threshold for number of rows guaranteed to be held in memory by the cartesian " + + "product operator") + .intConf + .createWithDefault(4096) val CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD = buildConf("spark.sql.cartesianProductExec.buffer.spill.threshold") .internal() - .doc("Threshold for number of rows buffered in cartesian product operator") + .doc("Threshold for number of rows to be spilled by cartesian product operator") .intConf .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt) @@ -870,6 +913,9 @@ class SQLConf extends Serializable with Logging { def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) + def adaptiveExecutionDisabledForJoining: Boolean = + getConf(ADAPTIVE_EXECUTION_DISABLED_FOR_JOINING) + def minNumPostShufflePartitions: Int = getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS) @@ -911,6 +957,8 @@ class SQLConf extends Serializable with Logging { def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED) + def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. @@ -1003,6 +1051,8 @@ class SQLConf extends Serializable with Logging { def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) + def groupByAliases: Boolean = getConf(GROUP_BY_ALIASES) + def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE) @@ -1019,11 +1069,19 @@ class SQLConf extends Serializable with Logging { def joinReorderDPStarFilter: Boolean = getConf(SQLConf.JOIN_REORDER_DP_STAR_FILTER) + def windowExecBufferInMemoryThreshold: Int = getConf(WINDOW_EXEC_BUFFER_IN_MEMORY_THRESHOLD) + def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD) + def sortMergeJoinExecBufferInMemoryThreshold: Int = + getConf(SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD) + def sortMergeJoinExecBufferSpillThreshold: Int = getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD) + def cartesianProductExecBufferInMemoryThreshold: Int = + getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD) + def cartesianProductExecBufferSpillThreshold: Int = getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD) @@ -1104,10 +1162,12 @@ class SQLConf extends Serializable with Logging { * not set yet, return `defaultValue`. */ def getConfString(key: String, defaultValue: String): String = { - val entry = sqlConfEntries.get(key) - if (entry != null && defaultValue != "") { - // Only verify configs in the SQLConf object - entry.valueConverter(defaultValue) + if (defaultValue != null && defaultValue != "") { + val entry = sqlConfEntries.get(key) + if (entry != null) { + // Only verify configs in the SQLConf object + entry.valueConverter(defaultValue) + } } Option(settings.get(key)).getOrElse(defaultValue) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index af1a9cee2962a..c6c0a605d89ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -81,4 +81,10 @@ object StaticSQLConf { "SQL configuration and the current database.") .booleanConf .createWithDefault(false) + + val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions") + .doc("Name of the class used to configure Spark Session extensions. The class should " + + "implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.") + .stringConf + .createOptional } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index e8f6884c025c2..1f1fb51addfd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -126,20 +126,36 @@ final class Decimal extends Ordered[Decimal] with Serializable { def set(decimal: BigDecimal): Decimal = { this.decimalVal = decimal this.longVal = 0L - this._precision = decimal.precision + if (decimal.precision <= decimal.scale) { + // For Decimal, we expect the precision is equal to or large than the scale, however, + // in BigDecimal, the digit count starts from the leftmost nonzero digit of the exact + // result. For example, the precision of 0.01 equals to 1 based on the definition, but + // the scale is 2. The expected precision should be 3. + this._precision = decimal.scale + 1 + } else { + this._precision = decimal.precision + } this._scale = decimal.scale this } /** - * Set this Decimal to the given BigInteger value. Will have precision 38 and scale 0. + * If the value is not in the range of long, convert it to BigDecimal and + * the precision and scale are based on the converted value. + * + * This code avoids BigDecimal object allocation as possible to improve runtime efficiency */ def set(bigintval: BigInteger): Decimal = { - this.decimalVal = null - this.longVal = bigintval.longValueExact() - this._precision = DecimalType.MAX_PRECISION - this._scale = 0 - this + try { + this.decimalVal = null + this.longVal = bigintval.longValueExact() + this._precision = DecimalType.MAX_PRECISION + this._scale = 0 + this + } catch { + case _: ArithmeticException => + set(BigDecimal(bigintval)) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala new file mode 100644 index 0000000000000..e881685ce6262 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.internal.Logging + + +/** + * Utils for handling schemas. + * + * TODO: Merge this file with [[org.apache.spark.ml.util.SchemaUtils]]. + */ +private[spark] object SchemaUtils extends Logging { + + /** + * Checks if input column names have duplicate identifiers. Prints a warning message if + * the duplication exists. + * + * @param columnNames column names to check + * @param colType column type name, used in a warning message + * @param caseSensitiveAnalysis whether duplication checks should be case sensitive or not + */ + def checkColumnNameDuplication( + columnNames: Seq[String], colType: String, caseSensitiveAnalysis: Boolean): Unit = { + val names = if (caseSensitiveAnalysis) { + columnNames + } else { + columnNames.map(_.toLowerCase) + } + if (names.distinct.length != names.length) { + val duplicateColumns = names.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => s"`$x`" + } + logWarning(s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}. " + + "You might need to assign different column names.") + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index d2ebca5a83dd3..5050318d96358 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Max} -import org.apache.spark.sql.catalyst.plans.{Cross, Inner, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.{Cross, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -152,7 +153,7 @@ class AnalysisErrorSuite extends AnalysisTest { "not supported within a window function" :: Nil) errorTest( - "distinct window function", + "distinct aggregate function in window", testRelation2.select( WindowExpression( AggregateExpression(Count(UnresolvedAttribute("b")), Complete, isDistinct = true), @@ -162,6 +163,16 @@ class AnalysisErrorSuite extends AnalysisTest { UnspecifiedFrame)).as('window)), "Distinct window functions are not supported" :: Nil) + errorTest( + "distinct function", + CatalystSqlParser.parsePlan("SELECT hex(DISTINCT a) FROM TaBlE"), + "hex does not support the modifier DISTINCT" :: Nil) + + errorTest( + "distinct window function", + CatalystSqlParser.parsePlan("SELECT percent_rank(DISTINCT a) over () FROM TaBlE"), + "percent_rank does not support the modifier DISTINCT" :: Nil) + errorTest( "nested aggregate functions", testRelation.groupBy('a)( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 82015b1e0671c..08d9313894c2d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.analysis +import java.net.URI import java.util.Locale import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf @@ -32,7 +33,10 @@ trait AnalysisTest extends PlanTest { private def makeAnalyzer(caseSensitive: Boolean): Analyzer = { val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) - val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) + val catalog = new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin, conf) + catalog.createDatabase( + CatalogDatabase("default", "", new URI("loc"), Map.empty), + ignoreIfExists = false) catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true) catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true) new Analyzer(catalog, conf) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala new file mode 100644 index 0000000000000..48a3ca2ccfb0b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DSLHintSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.analysis.AnalysisTest +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ + +class DSLHintSuite extends AnalysisTest { + lazy val a = 'a.int + lazy val b = 'b.string + lazy val c = 'c.string + lazy val r1 = LocalRelation(a, b, c) + + test("various hint parameters") { + comparePlans( + r1.hint("hint1"), + UnresolvedHint("hint1", Seq(), r1) + ) + + comparePlans( + r1.hint("hint1", 1, "a"), + UnresolvedHint("hint1", Seq(1, "a"), r1) + ) + + comparePlans( + r1.hint("hint1", 1, $"a"), + UnresolvedHint("hint1", Seq(1, $"a"), r1) + ) + + comparePlans( + r1.hint("hint1", Seq(1, 2, 3), Seq($"a", $"b", $"c")), + UnresolvedHint("hint1", Seq(Seq(1, 2, 3), Seq($"a", $"b", $"c")), r1) + ) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 8f43171f309a9..3df2530ece636 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -90,8 +90,14 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { checkType(Average(d1), DecimalType(6, 5)) checkType(Add(Add(d1, d2), d1), DecimalType(7, 2)) + checkType(Add(Add(d1, d1), d1), DecimalType(4, 1)) + checkType(Add(d1, Add(d1, d1)), DecimalType(4, 1)) checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2)) checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2)) + checkType(Subtract(Subtract(d2, d1), d1), DecimalType(7, 2)) + checkType(Multiply(Multiply(d1, d1), d2), DecimalType(11, 4)) + checkType(Divide(d2, Add(d1, d1)), DecimalType(10, 6)) + checkType(Sum(Add(d1, d1)), DecimalType(13, 1)) } test("Comparison operations") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index d101e2227462d..3d5148008c628 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -28,68 +28,70 @@ class ResolveHintsSuite extends AnalysisTest { test("invalid hints should be ignored") { checkAnalysis( - Hint("some_random_hint_that_does_not_exist", Seq("TaBlE"), table("TaBlE")), + UnresolvedHint("some_random_hint_that_does_not_exist", Seq("TaBlE"), table("TaBlE")), testRelation, caseSensitive = false) } test("case-sensitive or insensitive parameters") { checkAnalysis( - Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), - BroadcastHint(testRelation), + UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), + ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), caseSensitive = false) checkAnalysis( - Hint("MAPJOIN", Seq("table"), table("TaBlE")), - BroadcastHint(testRelation), + UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")), + ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), caseSensitive = false) checkAnalysis( - Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), - BroadcastHint(testRelation), + UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), + ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), caseSensitive = true) checkAnalysis( - Hint("MAPJOIN", Seq("table"), table("TaBlE")), + UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")), testRelation, caseSensitive = true) } test("multiple broadcast hint aliases") { checkAnalysis( - Hint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))), - Join(BroadcastHint(testRelation), BroadcastHint(testRelation2), Inner, None), + UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))), + Join(ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), + ResolvedHint(testRelation2, HintInfo(isBroadcastable = Option(true))), Inner, None), caseSensitive = false) } test("do not traverse past existing broadcast hints") { checkAnalysis( - Hint("MAPJOIN", Seq("table"), BroadcastHint(table("table").where('a > 1))), - BroadcastHint(testRelation.where('a > 1)).analyze, + UnresolvedHint("MAPJOIN", Seq("table"), + ResolvedHint(table("table").where('a > 1), HintInfo(isBroadcastable = Option(true)))), + ResolvedHint(testRelation.where('a > 1), HintInfo(isBroadcastable = Option(true))).analyze, caseSensitive = false) } test("should work for subqueries") { checkAnalysis( - Hint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")), - BroadcastHint(testRelation), + UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")), + ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), caseSensitive = false) checkAnalysis( - Hint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)), - BroadcastHint(testRelation), + UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)), + ResolvedHint(testRelation, HintInfo(isBroadcastable = Option(true))), caseSensitive = false) // Negative case: if the alias doesn't match, don't match the original table name. checkAnalysis( - Hint("MAPJOIN", Seq("table"), table("table").as("tableAlias")), + UnresolvedHint("MAPJOIN", Seq("table"), table("table").as("tableAlias")), testRelation, caseSensitive = false) } test("do not traverse past subquery alias") { checkAnalysis( - Hint("MAPJOIN", Seq("table"), table("table").where('a > 1).subquery('tableAlias)), + UnresolvedHint("MAPJOIN", Seq("table"), table("table").where('a > 1).subquery('tableAlias)), testRelation.where('a > 1).analyze, caseSensitive = false) } @@ -102,7 +104,8 @@ class ResolveHintsSuite extends AnalysisTest { |SELECT /*+ BROADCAST(ctetable) */ * FROM ctetable """.stripMargin ), - BroadcastHint(testRelation.where('a > 1).select('a)).select('a).analyze, + ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(isBroadcastable = Option(true))) + .select('a).analyze, caseSensitive = false) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala index f45a826869842..d0fe815052256 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand} import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.{LongType, NullType, TimestampType} /** @@ -91,12 +92,13 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter { test("convert TimeZoneAwareExpression") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType)))) - val converted = ResolveInlineTables(conf).convert(table) + val withTimeZone = ResolveTimeZone(conf).apply(table) + val LocalRelation(output, data) = ResolveInlineTables(conf).apply(withTimeZone) val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType) .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long] - assert(converted.output.map(_.dataType) == Seq(TimestampType)) - assert(converted.data.size == 1) - assert(converted.data(0).getLong(0) == correct) + assert(output.map(_.dataType) == Seq(TimestampType)) + assert(data.size == 1) + assert(data.head.getLong(0) == correct) } test("nullability inference in convert") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 011d09ff60641..2624f5586fd5d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -787,6 +788,12 @@ class TypeCoercionSuite extends PlanTest { } } + private val timeZoneResolver = ResolveTimeZone(new SQLConf) + + private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = { + timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan)) + } + test("WidenSetOperationTypes for except and intersect") { val firstTable = LocalRelation( AttributeReference("i", IntegerType)(), @@ -799,11 +806,10 @@ class TypeCoercionSuite extends PlanTest { AttributeReference("f", FloatType)(), AttributeReference("l", LongType)()) - val wt = TypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except] - val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes(Except(firstTable, secondTable)).asInstanceOf[Except] + val r2 = widenSetOperationTypes(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] checkOutput(r1.left, expectedTypes) checkOutput(r1.right, expectedTypes) checkOutput(r2.left, expectedTypes) @@ -838,10 +844,9 @@ class TypeCoercionSuite extends PlanTest { AttributeReference("p", ByteType)(), AttributeReference("q", DoubleType)()) - val wt = TypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val unionRelation = wt( + val unionRelation = widenSetOperationTypes( Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union] assert(unionRelation.children.length == 4) checkOutput(unionRelation.children.head, expectedTypes) @@ -862,17 +867,15 @@ class TypeCoercionSuite extends PlanTest { } } - val dp = TypeCoercion.WidenSetOperationTypes - val left1 = LocalRelation( AttributeReference("l", DecimalType(10, 8))()) val right1 = LocalRelation( AttributeReference("r", DecimalType(5, 5))()) val expectedType1 = Seq(DecimalType(10, 8)) - val r1 = dp(Union(left1, right1)).asInstanceOf[Union] - val r2 = dp(Except(left1, right1)).asInstanceOf[Except] - val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes(Union(left1, right1)).asInstanceOf[Union] + val r2 = widenSetOperationTypes(Except(left1, right1)).asInstanceOf[Except] + val r3 = widenSetOperationTypes(Intersect(left1, right1)).asInstanceOf[Intersect] checkOutput(r1.children.head, expectedType1) checkOutput(r1.children.last, expectedType1) @@ -891,17 +894,17 @@ class TypeCoercionSuite extends PlanTest { val plan2 = LocalRelation( AttributeReference("r", rType)()) - val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union] - val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except] - val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes(Union(plan1, plan2)).asInstanceOf[Union] + val r2 = widenSetOperationTypes(Except(plan1, plan2)).asInstanceOf[Except] + val r3 = widenSetOperationTypes(Intersect(plan1, plan2)).asInstanceOf[Intersect] checkOutput(r1.children.last, Seq(expectedType)) checkOutput(r2.right, Seq(expectedType)) checkOutput(r3.right, Seq(expectedType)) - val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union] - val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except] - val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect] + val r4 = widenSetOperationTypes(Union(plan2, plan1)).asInstanceOf[Union] + val r5 = widenSetOperationTypes(Except(plan2, plan1)).asInstanceOf[Except] + val r6 = widenSetOperationTypes(Intersect(plan2, plan1)).asInstanceOf[Intersect] checkOutput(r4.children.last, Seq(expectedType)) checkOutput(r5.left, Seq(expectedType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala new file mode 100644 index 0000000000000..2539ea615ff92 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.catalog + +import java.net.URI +import java.nio.file.{Files, Path} + +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.types.StructType + +/** + * Test Suite for external catalog events + */ +class ExternalCatalogEventSuite extends SparkFunSuite { + + protected def newCatalog: ExternalCatalog = new InMemoryCatalog() + + private def testWithCatalog( + name: String)( + f: (ExternalCatalog, Seq[ExternalCatalogEvent] => Unit) => Unit): Unit = test(name) { + val catalog = newCatalog + val recorder = mutable.Buffer.empty[ExternalCatalogEvent] + catalog.addListener(new ExternalCatalogEventListener { + override def onEvent(event: ExternalCatalogEvent): Unit = { + recorder += event + } + }) + f(catalog, (expected: Seq[ExternalCatalogEvent]) => { + val actual = recorder.clone() + recorder.clear() + assert(expected === actual) + }) + } + + private def createDbDefinition(uri: URI): CatalogDatabase = { + CatalogDatabase(name = "db5", description = "", locationUri = uri, Map.empty) + } + + private def createDbDefinition(): CatalogDatabase = { + createDbDefinition(preparePath(Files.createTempDirectory("db_"))) + } + + private def preparePath(path: Path): URI = path.normalize().toUri + + testWithCatalog("database") { (catalog, checkEvents) => + // CREATE + val dbDefinition = createDbDefinition() + + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + catalog.createDatabase(dbDefinition, ignoreIfExists = true) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + intercept[AnalysisException] { + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + } + checkEvents(CreateDatabasePreEvent("db5") :: Nil) + + // DROP + intercept[AnalysisException] { + catalog.dropDatabase("db4", ignoreIfNotExists = false, cascade = false) + } + checkEvents(DropDatabasePreEvent("db4") :: Nil) + + catalog.dropDatabase("db5", ignoreIfNotExists = false, cascade = false) + checkEvents(DropDatabasePreEvent("db5") :: DropDatabaseEvent("db5") :: Nil) + + catalog.dropDatabase("db4", ignoreIfNotExists = true, cascade = false) + checkEvents(DropDatabasePreEvent("db4") :: DropDatabaseEvent("db4") :: Nil) + } + + testWithCatalog("table") { (catalog, checkEvents) => + val path1 = Files.createTempDirectory("db_") + val path2 = Files.createTempDirectory(path1, "tbl_") + val uri1 = preparePath(path1) + val uri2 = preparePath(path2) + + // CREATE + val dbDefinition = createDbDefinition(uri1) + + val storage = CatalogStorageFormat.empty.copy( + locationUri = Option(uri2)) + val tableDefinition = CatalogTable( + identifier = TableIdentifier("tbl1", Some("db5")), + tableType = CatalogTableType.MANAGED, + storage = storage, + schema = new StructType().add("id", "long")) + + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + catalog.createTable(tableDefinition, ignoreIfExists = false) + checkEvents(CreateTablePreEvent("db5", "tbl1") :: CreateTableEvent("db5", "tbl1") :: Nil) + + catalog.createTable(tableDefinition, ignoreIfExists = true) + checkEvents(CreateTablePreEvent("db5", "tbl1") :: CreateTableEvent("db5", "tbl1") :: Nil) + + intercept[AnalysisException] { + catalog.createTable(tableDefinition, ignoreIfExists = false) + } + checkEvents(CreateTablePreEvent("db5", "tbl1") :: Nil) + + // RENAME + catalog.renameTable("db5", "tbl1", "tbl2") + checkEvents( + RenameTablePreEvent("db5", "tbl1", "tbl2") :: + RenameTableEvent("db5", "tbl1", "tbl2") :: Nil) + + intercept[AnalysisException] { + catalog.renameTable("db5", "tbl1", "tbl2") + } + checkEvents(RenameTablePreEvent("db5", "tbl1", "tbl2") :: Nil) + + // DROP + intercept[AnalysisException] { + catalog.dropTable("db5", "tbl1", ignoreIfNotExists = false, purge = true) + } + checkEvents(DropTablePreEvent("db5", "tbl1") :: Nil) + + catalog.dropTable("db5", "tbl2", ignoreIfNotExists = false, purge = true) + checkEvents(DropTablePreEvent("db5", "tbl2") :: DropTableEvent("db5", "tbl2") :: Nil) + + catalog.dropTable("db5", "tbl2", ignoreIfNotExists = true, purge = true) + checkEvents(DropTablePreEvent("db5", "tbl2") :: DropTableEvent("db5", "tbl2") :: Nil) + } + + testWithCatalog("function") { (catalog, checkEvents) => + // CREATE + val dbDefinition = createDbDefinition() + + val functionDefinition = CatalogFunction( + identifier = FunctionIdentifier("fn7", Some("db5")), + className = "", + resources = Seq.empty) + + val newIdentifier = functionDefinition.identifier.copy(funcName = "fn4") + val renamedFunctionDefinition = functionDefinition.copy(identifier = newIdentifier) + + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + catalog.createFunction("db5", functionDefinition) + checkEvents(CreateFunctionPreEvent("db5", "fn7") :: CreateFunctionEvent("db5", "fn7") :: Nil) + + intercept[AnalysisException] { + catalog.createFunction("db5", functionDefinition) + } + checkEvents(CreateFunctionPreEvent("db5", "fn7") :: Nil) + + // RENAME + catalog.renameFunction("db5", "fn7", "fn4") + checkEvents( + RenameFunctionPreEvent("db5", "fn7", "fn4") :: + RenameFunctionEvent("db5", "fn7", "fn4") :: Nil) + intercept[AnalysisException] { + catalog.renameFunction("db5", "fn7", "fn4") + } + checkEvents(RenameFunctionPreEvent("db5", "fn7", "fn4") :: Nil) + + // DROP + intercept[AnalysisException] { + catalog.dropFunction("db5", "fn7") + } + checkEvents(DropFunctionPreEvent("db5", "fn7") :: Nil) + + catalog.dropFunction("db5", "fn4") + checkEvents(DropFunctionPreEvent("db5", "fn4") :: DropFunctionEvent("db5", "fn4") :: Nil) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 42db4398e5072..54ecf442a8c9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -439,6 +439,18 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac assert(catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "unknown"))).isEmpty) } + test("SPARK-21457: list partitions with special chars") { + val catalog = newBasicCatalog() + assert(catalog.listPartitions("db2", "tbl1").isEmpty) + + val part1 = CatalogTablePartition(Map("a" -> "1", "b" -> "i+j"), storageFormat) + val part2 = CatalogTablePartition(Map("a" -> "1", "b" -> "i.j"), storageFormat) + catalog.createPartitions("db2", "tbl1", Seq(part1, part2), ignoreIfExists = false) + + assert(catalog.listPartitions("db2", "tbl1", Some(part1.spec)).map(_.spec) == Seq(part1.spec)) + assert(catalog.listPartitions("db2", "tbl1", Some(part2.spec)).map(_.spec) == Seq(part2.spec)) + } + test("list partitions by filter") { val tz = TimeZone.getDefault.getID val catalog = newBasicCatalog() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index be8903000a0d1..9c1b6385d049b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -498,17 +498,6 @@ abstract class SessionCatalogSuite extends PlanTest { } } - test("get option of table metadata") { - withBasicCatalog { catalog => - assert(catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("db2"))) - == Option(catalog.externalCatalog.getTable("db2", "tbl1"))) - assert(catalog.getTableMetadataOption(TableIdentifier("unknown_table", Some("db2"))).isEmpty) - intercept[NoSuchDatabaseException] { - catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("unknown_db"))) - } - } - } - test("lookup table relation") { withBasicCatalog { catalog => val tempTable1 = Range(1, 10, 1, 10) @@ -517,14 +506,14 @@ abstract class SessionCatalogSuite extends PlanTest { catalog.setCurrentDatabase("db2") // If we explicitly specify the database, we'll look up the relation in that database assert(catalog.lookupRelation(TableIdentifier("tbl1", Some("db2"))).children.head - .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) + .asInstanceOf[UnresolvedCatalogRelation].tableMeta == metastoreTable1) // Otherwise, we'll first look up a temporary table with the same name assert(catalog.lookupRelation(TableIdentifier("tbl1")) == SubqueryAlias("tbl1", tempTable1)) // Then, if that does not exist, look up the relation in the current database catalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) assert(catalog.lookupRelation(TableIdentifier("tbl1")).children.head - .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) + .asInstanceOf[UnresolvedCatalogRelation].tableMeta == metastoreTable1) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 080f11b769388..bb1955a1ae242 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -355,12 +355,18 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { checkNullable[String](true) } - test("null check for map key") { + test("null check for map key: String") { val encoder = ExpressionEncoder[Map[String, Int]]() val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2)))) assert(e.getMessage.contains("Cannot use null as map key")) } + test("null check for map key: Integer") { + val encoder = ExpressionEncoder[Map[Integer, String]]() + val e = intercept[RuntimeException](encoder.toRow(Map((1, "a"), (null, "b")))) + assert(e.getMessage.contains("Cannot use null as map key")) + } + private def encodeDecodeTest[T : ExpressionEncoder]( input: T, testName: String): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 9978f35a03810..257c2a3bef974 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -76,6 +76,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } checkEvaluation(DayOfYear(Literal.create(null, DateType)), null) + + checkEvaluation(DayOfYear(Literal(new Date(sdf.parse("1582-10-15 13:10:15").getTime))), 288) + checkEvaluation(DayOfYear(Literal(new Date(sdf.parse("1582-10-04 13:10:15").getTime))), 277) checkConsistencyBetweenInterpretedAndCodegen(DayOfYear, DateType) } @@ -96,6 +99,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkEvaluation(Year(Literal(new Date(sdf.parse("1582-01-01 13:10:15").getTime))), 1582) + checkEvaluation(Year(Literal(new Date(sdf.parse("1581-12-31 13:10:15").getTime))), 1581) checkConsistencyBetweenInterpretedAndCodegen(Year, DateType) } @@ -116,6 +121,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + + checkEvaluation(Quarter(Literal(new Date(sdf.parse("1582-10-01 13:10:15").getTime))), 4) + checkEvaluation(Quarter(Literal(new Date(sdf.parse("1582-09-30 13:10:15").getTime))), 3) checkConsistencyBetweenInterpretedAndCodegen(Quarter, DateType) } @@ -125,6 +133,10 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Month(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 4) checkEvaluation(Month(Cast(Literal(ts), DateType, gmtId)), 11) + checkEvaluation(Month(Literal(new Date(sdf.parse("1582-04-28 13:10:15").getTime))), 4) + checkEvaluation(Month(Literal(new Date(sdf.parse("1582-10-04 13:10:15").getTime))), 10) + checkEvaluation(Month(Literal(new Date(sdf.parse("1582-10-15 13:10:15").getTime))), 10) + val c = Calendar.getInstance() (2003 to 2004).foreach { y => (0 to 3).foreach { m => @@ -146,6 +158,10 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(DayOfMonth(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 8) checkEvaluation(DayOfMonth(Cast(Literal(ts), DateType, gmtId)), 8) + checkEvaluation(DayOfMonth(Literal(new Date(sdf.parse("1582-04-28 13:10:15").getTime))), 28) + checkEvaluation(DayOfMonth(Literal(new Date(sdf.parse("1582-10-15 13:10:15").getTime))), 15) + checkEvaluation(DayOfMonth(Literal(new Date(sdf.parse("1582-10-04 13:10:15").getTime))), 4) + val c = Calendar.getInstance() (1999 to 2000).foreach { y => c.set(y, 0, 1, 0, 0, 0) @@ -160,7 +176,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Seconds") { assert(Second(Literal.create(null, DateType), gmtId).resolved === false) - assert(Second(Cast(Literal(d), TimestampType), None).resolved === true) + assert(Second(Cast(Literal(d), TimestampType, gmtId), gmtId).resolved === true) checkEvaluation(Second(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) checkEvaluation(Second(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 15) checkEvaluation(Second(Literal(ts), gmtId), 15) @@ -186,6 +202,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(WeekOfYear(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 15) checkEvaluation(WeekOfYear(Cast(Literal(ts), DateType, gmtId)), 45) checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType, gmtId)), 18) + checkEvaluation(WeekOfYear(Literal(new Date(sdf.parse("1582-10-15 13:10:15").getTime))), 40) + checkEvaluation(WeekOfYear(Literal(new Date(sdf.parse("1582-10-04 13:10:15").getTime))), 40) checkConsistencyBetweenInterpretedAndCodegen(WeekOfYear, DateType) } @@ -220,7 +238,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Hour") { assert(Hour(Literal.create(null, DateType), gmtId).resolved === false) - assert(Hour(Literal(ts), None).resolved === true) + assert(Hour(Literal(ts), gmtId).resolved === true) checkEvaluation(Hour(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) checkEvaluation(Hour(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 13) checkEvaluation(Hour(Literal(ts), gmtId), 13) @@ -246,7 +264,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Minute") { assert(Minute(Literal.create(null, DateType), gmtId).resolved === false) - assert(Minute(Literal(ts), None).resolved === true) + assert(Minute(Literal(ts), gmtId).resolved === true) checkEvaluation(Minute(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) checkEvaluation( Minute(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 10) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 1ba6dd1c5e8ca..b6399edb68dd6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -25,10 +25,12 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -45,7 +47,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val serializer = new JavaSerializer(new SparkConf()).newInstance - val expr: Expression = serializer.deserialize(serializer.serialize(expression)) + val resolver = ResolveTimeZone(new SQLConf) + val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index c5b72235e5db0..53b54de606930 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -21,6 +21,7 @@ import java.util.Calendar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, GenericArrayData, PermissiveMode} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -39,6 +40,10 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { |"fb:testid":"1234"} |""".stripMargin + /* invalid json with leading nulls would trigger java.io.CharConversionException + in Jackson's JsonFactory.createParser(byte[]) due to RFC-4627 encoding detection */ + val badJson = "\u0000\u0000\u0000A\u0001AAA" + test("$.store.bicycle") { checkEvaluation( GetJsonObject(Literal(json), Literal("$.store.bicycle")), @@ -224,6 +229,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { null) } + test("SPARK-16548: character conversion") { + checkEvaluation( + GetJsonObject(Literal(badJson), Literal("$.a")), + null + ) + } + test("non foldable literal") { checkEvaluation( GetJsonObject(NonFoldableLiteral(json), NonFoldableLiteral("$.fb:testid")), @@ -340,6 +352,12 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow(null, null, null, null, null)) } + test("SPARK-16548: json_tuple - invalid json with leading nulls") { + checkJsonTuple( + JsonTuple(Literal(badJson) :: jsonTupleQuery), + InternalRow(null, null, null, null, null)) + } + test("json_tuple - preserve newlines") { checkJsonTuple( JsonTuple(Literal("{\"a\":\"b\nc\"}") :: Literal("a") :: Nil), @@ -436,6 +454,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) } + test("SPARK-20549: from_json bad UTF-8") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + JsonToStructs(schema, Map.empty, Literal(badJson), gmtId), + null) + } + test("from_json with timestamp") { val schema = StructType(StructField("t", TimestampType) :: Nil) @@ -566,4 +591,26 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { """{"t":"2015-12-31T16:00:00"}""" ) } + + test("to_json: verify MapType's value type instead of key type") { + // Keys in map are treated as strings when converting to JSON. The type doesn't matter at all. + val mapType1 = MapType(CalendarIntervalType, IntegerType) + val schema1 = StructType(StructField("a", mapType1) :: Nil) + val struct1 = Literal.create(null, schema1) + checkEvaluation( + StructsToJson(Map.empty, struct1, gmtId), + null + ) + + // The value type must be valid for converting to JSON. + val mapType2 = MapType(IntegerType, CalendarIntervalType) + val schema2 = StructType(StructField("a", mapType2) :: Nil) + val struct2 = Literal.create(null, schema2) + intercept[TreeNodeException[_]] { + checkEvaluation( + StructsToJson(Map.empty, struct2, gmtId), + null + ) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 6b5bfac94645c..69ada8216515d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -252,6 +252,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3)) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0)) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0)) + + val doublePi: Double = 3.1415 + val floatPi: Float = 3.1415f + val longLit: Long = 12345678901234567L + checkEvaluation(Ceil(doublePi), 4L, EmptyRow) + checkEvaluation(Ceil(floatPi.toDouble), 4L, EmptyRow) + checkEvaluation(Ceil(longLit), longLit, EmptyRow) + checkEvaluation(Ceil(-doublePi), -3L, EmptyRow) + checkEvaluation(Ceil(-floatPi.toDouble), -3L, EmptyRow) + checkEvaluation(Ceil(-longLit), -longLit, EmptyRow) } test("floor") { @@ -262,6 +272,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3)) checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0)) checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0)) + + val doublePi: Double = 3.1415 + val floatPi: Float = 3.1415f + val longLit: Long = 12345678901234567L + checkEvaluation(Floor(doublePi), 3L, EmptyRow) + checkEvaluation(Floor(floatPi.toDouble), 3L, EmptyRow) + checkEvaluation(Floor(longLit), longLit, EmptyRow) + checkEvaluation(Floor(-doublePi), -4L, EmptyRow) + checkEvaluation(Floor(-floatPi.toDouble), -4L, EmptyRow) + checkEvaluation(Floor(-longLit), -longLit, EmptyRow) } test("factorial") { @@ -546,15 +566,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), BigDecimal(3.141593), BigDecimal(3.1415927)) - // round_scale > current_scale would result in precision increase - // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null + (0 to 7).foreach { i => checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow) } (8 to 10).foreach { scale => - checkEvaluation(Round(bdPi, scale), null, EmptyRow) - checkEvaluation(BRound(bdPi, scale), null, EmptyRow) + checkEvaluation(Round(bdPi, scale), bdPi, EmptyRow) + checkEvaluation(BRound(bdPi, scale), bdPi, EmptyRow) } DataTypeTestUtils.numericTypes.foreach { dataType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala index 190fab5d249bb..aa61ba2bff2bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala @@ -137,4 +137,23 @@ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper { // verify that we can support up to 5000 ordering comparisons, which should be sufficient GenerateOrdering.generate(Array.fill(5000)(sortOrder)) } + + test("SPARK-21344: BinaryType comparison does signed byte array comparison") { + val data = Seq( + (Array[Byte](1), Array[Byte](-1)), + (Array[Byte](1, 1, 1, 1, 1), Array[Byte](1, 1, 1, 1, -1)), + (Array[Byte](1, 1, 1, 1, 1, 1, 1, 1, 1), Array[Byte](1, 1, 1, 1, 1, 1, 1, 1, -1)) + ) + data.foreach { case (b1, b2) => + val rowOrdering = InterpretedOrdering.forSchema(Seq(BinaryType)) + val genOrdering = GenerateOrdering.generate( + BoundReference(0, BinaryType, nullable = true).asc :: Nil) + val rowType = StructType(StructField("b", BinaryType, nullable = true) :: Nil) + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) + val rowB1 = toCatalyst(Row(b1)).asInstanceOf[InternalRow] + val rowB2 = toCatalyst(Row(b2)).asInstanceOf[InternalRow] + assert(rowOrdering.compare(rowB1, rowB2) < 0) + assert(genOrdering.compare(rowB1, rowB2) < 0) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 6fe295c3dd936..ef510a95ef446 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -35,7 +35,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test(s"3VL $name") { truthTable.foreach { case (l, r, answer) => - val expr = op(NonFoldableLiteral(l, BooleanType), NonFoldableLiteral(r, BooleanType)) + val expr = op(NonFoldableLiteral.create(l, BooleanType), + NonFoldableLiteral.create(r, BooleanType)) checkEvaluation(expr, answer) } } @@ -72,7 +73,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (false, true) :: (null, null) :: Nil notTrueTable.foreach { case (v, answer) => - checkEvaluation(Not(NonFoldableLiteral(v, BooleanType)), answer) + checkEvaluation(Not(NonFoldableLiteral.create(v, BooleanType)), answer) } checkConsistencyBetweenInterpretedAndCodegen(Not, BooleanType) } @@ -120,22 +121,26 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, null, null) :: Nil) test("IN") { - checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq(Literal(1), Literal(2))), null) - checkEvaluation(In(NonFoldableLiteral(null, IntegerType), - Seq(NonFoldableLiteral(null, IntegerType))), null) - checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq.empty), null) + checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1), + Literal(2))), null) + checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), + Seq(NonFoldableLiteral.create(null, IntegerType))), null) + checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null) checkEvaluation(In(Literal(1), Seq.empty), false) - checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral(null, IntegerType))), null) - checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), true) - checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), null) + checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null) + checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + true) + checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + null) checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) checkEvaluation( - And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))), + And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), + Literal(2)))), true) - val ns = NonFoldableLiteral(null, StringType) + val ns = NonFoldableLiteral.create(null, StringType) checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null) checkEvaluation(In(ns, Seq(ns)), null) checkEvaluation(In(Literal("a"), Seq(ns)), null) @@ -155,7 +160,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { case _ => value } } - val input = inputData.map(NonFoldableLiteral(_, t)) + val input = inputData.map(NonFoldableLiteral.create(_, t)) val expected = if (inputData(0) == null) { null } else if (inputData.slice(1, 10).contains(inputData(0))) { @@ -279,7 +284,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { test("BinaryComparison: null test") { // Use -1 (default value for codegen) which can trigger some weird bugs, e.g. SPARK-14757 val normalInt = Literal(-1) - val nullInt = NonFoldableLiteral(null, IntegerType) + val nullInt = NonFoldableLiteral.create(null, IntegerType) def nullTest(op: (Expression, Expression) => Expression): Unit = { checkEvaluation(op(normalInt, nullInt), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala new file mode 100644 index 0000000000000..e9d21f8a8ebcd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +class GenerateUnsafeProjectionSuite extends SparkFunSuite { + test("Test unsafe projection string access pattern") { + val dataType = (new StructType).add("a", StringType) + val exprs = BoundReference(0, dataType, nullable = true) :: Nil + val projection = GenerateUnsafeProjection.generate(exprs) + val result = projection.apply(InternalRow(AlwaysNull)) + assert(!result.isNullAt(0)) + assert(result.getStruct(0, 1).isNullAt(0)) + } +} + +object AlwaysNull extends InternalRow { + override def numFields: Int = 1 + override def setNullAt(i: Int): Unit = {} + override def copy(): InternalRow = this + override def anyNull: Boolean = true + override def isNullAt(ordinal: Int): Boolean = true + override def update(i: Int, value: Any): Unit = notSupported + override def getBoolean(ordinal: Int): Boolean = notSupported + override def getByte(ordinal: Int): Byte = notSupported + override def getShort(ordinal: Int): Short = notSupported + override def getInt(ordinal: Int): Int = notSupported + override def getLong(ordinal: Int): Long = notSupported + override def getFloat(ordinal: Int): Float = notSupported + override def getDouble(ordinal: Int): Double = notSupported + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported + override def getUTF8String(ordinal: Int): UTF8String = notSupported + override def getBinary(ordinal: Int): Array[Byte] = notSupported + override def getInterval(ordinal: Int): CalendarInterval = notSupported + override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported + override def getArray(ordinal: Int): ArrayData = notSupported + override def getMap(ordinal: Int): MapData = notSupported + override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported + private def notSupported: Nothing = throw new UnsupportedOperationException +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 935bff7cef2e8..c275f997ba6e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.Row class BooleanSimplificationSuite extends PlanTest with PredicateHelper { @@ -42,6 +43,16 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) + val testRelationWithData = LocalRelation.fromExternalRows( + testRelation.output, Seq(Row(1, 2, 3, "abc")) + ) + + private def checkCondition(input: Expression, expected: LogicalPlan): Unit = { + val plan = testRelationWithData.where(input).analyze + val actual = Optimize.execute(plan) + comparePlans(actual, expected) + } + private def checkCondition(input: Expression, expected: Expression): Unit = { val plan = testRelation.where(input).analyze val actual = Optimize.execute(plan) @@ -160,4 +171,12 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { testRelation.where('a > 2 || ('b > 3 && 'b < 5))) comparePlans(actual, expected) } + + test("Complementation Laws") { + checkCondition('a && !'a, testRelation) + checkCondition(!'a && 'a, testRelation) + + checkCondition('a || !'a, testRelationWithData) + checkCondition(!'a || 'a, testRelationWithData) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 589607e3ad5cb..a0a0daea7d075 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -321,15 +321,14 @@ class ColumnPruningSuite extends PlanTest { Project(Seq($"x.key", $"y.key"), Join( SubqueryAlias("x", input), - BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze + ResolvedHint(SubqueryAlias("y", input)), Inner, None)).analyze val optimized = Optimize.execute(query) val expected = Join( Project(Seq($"x.key"), SubqueryAlias("x", input)), - BroadcastHint( - Project(Seq($"y.key"), SubqueryAlias("y", input))), + ResolvedHint(Project(Seq($"y.key"), SubqueryAlias("y", input))), Inner, None).analyze comparePlans(optimized, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala new file mode 100644 index 0000000000000..d4f37e2a5e877 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types._ + +class EliminateMapObjectsSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = { + Batch("EliminateMapObjects", FixedPoint(50), + NullPropagation(conf), + SimplifyCasts, + EliminateMapObjects) :: Nil + } + } + + implicit private def intArrayEncoder = ExpressionEncoder[Array[Int]]() + implicit private def doubleArrayEncoder = ExpressionEncoder[Array[Double]]() + + test("SPARK-20254: Remove unnecessary data conversion for primitive array") { + val intObjType = ObjectType(classOf[Array[Int]]) + val intInput = LocalRelation('a.array(ArrayType(IntegerType, false))) + val intQuery = intInput.deserialize[Array[Int]].analyze + val intOptimized = Optimize.execute(intQuery) + val intExpected = DeserializeToObject( + Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false), + AttributeReference("obj", intObjType, true)(), intInput) + comparePlans(intOptimized, intExpected) + + val doubleObjType = ObjectType(classOf[Array[Double]]) + val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false))) + val doubleQuery = doubleInput.deserialize[Array[Double]].analyze + val doubleOptimized = Optimize.execute(doubleQuery) + val doubleExpected = DeserializeToObject( + Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false), + AttributeReference("obj", doubleObjType, true)(), doubleInput) + comparePlans(doubleOptimized, doubleExpected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 950aa2379517e..d4d281e7e05db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -798,12 +798,12 @@ class FilterPushdownSuite extends PlanTest { } test("broadcast hint") { - val originalQuery = BroadcastHint(testRelation) + val originalQuery = ResolvedHint(testRelation) .where('a === 2L && 'b + Rand(10).as("rnd") === 3) val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = BroadcastHint(testRelation.where('a === 2L)) + val correctAnswer = ResolvedHint(testRelation.where('a === 2L)) .where('b + Rand(10).as("rnd") === 3) .analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index c8fe37462726a..9a4bcdb011435 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -33,7 +33,8 @@ class InferFiltersFromConstraintsSuite extends PlanTest { PushPredicateThroughJoin, PushDownPredicate, InferFiltersFromConstraints(conf), - CombineFilters) :: Nil + CombineFilters, + BooleanSimplification) :: Nil } object OptimizeWithConstraintPropagationDisabled extends RuleExecutor[LogicalPlan] { @@ -172,7 +173,12 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val t1 = testRelation.subquery('t1) val t2 = testRelation.subquery('t2) - val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t") + // We should prevent `Coalese(a, b)` from recursively creating complicated constraints through + // the constraint inference procedure. + val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)) + // We hide an `Alias` inside the child's child's expressions, to cover the situation reported + // in [SPARK-20700]. + .select('int_col, 'd, 'a).as("t") .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr @@ -180,22 +186,18 @@ class InferFiltersFromConstraintsSuite extends PlanTest { .analyze val correctAnswer = t1 .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) - && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a - && Coalesce(Seq('a, 'a)) <=> 'b && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a)) - && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b)) - && Coalesce(Seq('a, 'b)) <=> Coalesce(Seq('b, 'b)) && Coalesce(Seq('a, 'b)) === 'b + && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) + && Coalesce(Seq('b, 'b)) <=> 'a && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) + && 'a === Coalesce(Seq('a, 'b)) && Coalesce(Seq('a, 'b)) === 'b && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) - && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)) - && Coalesce(Seq('b, 'b)) <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b) - .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t") + && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b))) + .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)) + .select('int_col, 'd, 'a).as("t") .join(t2 .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) - && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a - && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))), Inner, - Some("t.a".attr === "t2.a".attr - && "t.d".attr === "t2.a".attr - && "t.int_col".attr === "t2.a".attr - && Coalesce(Seq("t.d".attr, "t.d".attr)) <=> "t.int_col".attr)) + && 'a <=> Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'a)) && 'a <=> 'a), Inner, + Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr + && "t.int_col".attr === "t2.a".attr)) .analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index a43d78c7bd447..105407d43bf39 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -129,14 +129,14 @@ class JoinOptimizationSuite extends PlanTest { Project(Seq($"x.key", $"y.key"), Join( SubqueryAlias("x", input), - BroadcastHint(SubqueryAlias("y", input)), Cross, None)).analyze + ResolvedHint(SubqueryAlias("y", input)), Cross, None)).analyze val optimized = Optimize.execute(query) val expected = Join( Project(Seq($"x.key"), SubqueryAlias("x", input)), - BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input))), + ResolvedHint(Project(Seq($"y.key"), SubqueryAlias("y", input))), Cross, None).analyze comparePlans(optimized, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala index fdde89d079bc0..50398788c605c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst.optimizer -/* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.{BooleanType, StringType} class LikeSimplificationSuite extends PlanTest { @@ -100,4 +100,10 @@ class LikeSimplificationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("null pattern") { + val originalQuery = testRelation.where('a like Literal(null, StringType)).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, testRelation.where(Literal(null, BooleanType)).analyze) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index d8937321ecb98..f12f0f5eb4cd4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -166,7 +166,7 @@ class OptimizeInSuite extends PlanTest { val optimizedPlan = OptimizeIn(conf.copy(OPTIMIZER_INSET_CONVERSION_THRESHOLD -> 2))(plan) optimizedPlan match { case Filter(cond, _) - if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 => + if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getSet().size == 3 => // pass case _ => fail("Unexpected result for OptimizedIn") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index c261a6091d476..38dff4733f714 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -142,7 +142,7 @@ class PropagateEmptyRelationSuite extends PlanTest { comparePlans(optimized, correctAnswer.analyze) } - test("propagate empty relation through Aggregate without aggregate function") { + test("propagate empty relation through Aggregate with grouping expressions") { val query = testRelation1 .where(false) .groupBy('a)('a, ('a + 1).as('x)) @@ -153,13 +153,13 @@ class PropagateEmptyRelationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("don't propagate empty relation through Aggregate with aggregate function") { + test("don't propagate empty relation through Aggregate without grouping expressions") { val query = testRelation1 .where(false) - .groupBy('a)(count('a)) + .groupBy()() val optimized = Optimize.execute(query.analyze) - val correctAnswer = LocalRelation('a.int).groupBy('a)(count('a)).analyze + val correctAnswer = LocalRelation('a.int).groupBy()().analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index 3964fa3924b24..4490523369006 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -30,7 +30,7 @@ class DataTypeParserSuite extends SparkFunSuite { } } - def intercept(sql: String): Unit = + def intercept(sql: String): ParseException = intercept[ParseException](CatalystSqlParser.parseDataType(sql)) def unsupported(dataTypeString: String): Unit = { @@ -118,6 +118,11 @@ class DataTypeParserSuite extends SparkFunSuite { unsupported("struct") + test("Do not print empty parentheses for no params") { + assert(intercept("unkwon").getMessage.contains("unkwon is not supported")) + assert(intercept("unkwon(1,2,3)").getMessage.contains("unkwon(1,2,3) is not supported")) + } + // DataType parser accepts certain reserved keywords. checkDataType( "Struct", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index e7f3b64a71130..f06219198bb58 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -39,12 +40,17 @@ class ExpressionParserSuite extends PlanTest { import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ - def assertEqual(sqlCommand: String, e: Expression): Unit = { - compareExpressions(parseExpression(sqlCommand), e) + val defaultParser = CatalystSqlParser + + def assertEqual( + sqlCommand: String, + e: Expression, + parser: ParserInterface = defaultParser): Unit = { + compareExpressions(parser.parseExpression(sqlCommand), e) } def intercept(sqlCommand: String, messages: String*): Unit = { - val e = intercept[ParseException](parseExpression(sqlCommand)) + val e = intercept[ParseException](defaultParser.parseExpression(sqlCommand)) messages.foreach { message => assert(e.message.contains(message)) } @@ -101,7 +107,7 @@ class ExpressionParserSuite extends PlanTest { test("long binary logical expressions") { def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = { val sql = (1 to 1000).map(x => s"$x == $x").mkString(op) - val e = parseExpression(sql) + val e = defaultParser.parseExpression(sql) assert(e.collect { case _: EqualTo => true }.size === 1000) assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999) } @@ -160,6 +166,15 @@ class ExpressionParserSuite extends PlanTest { assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%")) } + test("like expressions with ESCAPED_STRING_LITERALS = true") { + val conf = new SQLConf() + conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true") + val parser = new CatalystSqlParser(conf) + assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser) + assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser) + assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser) + } + test("is null expressions") { assertEqual("a is null", 'a.isNull) assertEqual("a is not null", 'a.isNotNull) @@ -211,7 +226,7 @@ class ExpressionParserSuite extends PlanTest { assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) assertEqual("`select`(all a, b)", 'select.function('a, 'b)) - assertEqual("foo(a as x, b as e)", 'foo.function('a as 'x, 'b as 'e)) + intercept("foo(a x)", "extraneous input 'x'") } test("window function expressions") { @@ -310,7 +325,9 @@ class ExpressionParserSuite extends PlanTest { assertEqual("a.b", UnresolvedAttribute("a.b")) assertEqual("`select`.b", UnresolvedAttribute("select.b")) assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. - assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b")) + assertEqual( + "struct(a, b).b", + namedStruct(NamePlaceholder, 'a, NamePlaceholder, 'b).getField("b")) } test("reference") { @@ -413,38 +430,87 @@ class ExpressionParserSuite extends PlanTest { } test("strings") { - // Single Strings. - assertEqual("\"hello\"", "hello") - assertEqual("'hello'", "hello") - - // Multi-Strings. - assertEqual("\"hello\" 'world'", "helloworld") - assertEqual("'hello' \" \" 'world'", "hello world") - - // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a - // regular '%'; to get the correct result you need to add another escaped '\'. - // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? - assertEqual("'pattern%'", "pattern%") - assertEqual("'no-pattern\\%'", "no-pattern\\%") - assertEqual("'pattern\\\\%'", "pattern\\%") - assertEqual("'pattern\\\\\\%'", "pattern\\\\%") - - // Escaped characters. - // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html - assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00') - assertEqual("'\\''", "\'") // Single quote - assertEqual("'\\\"'", "\"") // Double quote - assertEqual("'\\b'", "\b") // Backspace - assertEqual("'\\n'", "\n") // Newline - assertEqual("'\\r'", "\r") // Carriage return - assertEqual("'\\t'", "\t") // Tab character - assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows) - - // Octals - assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") - - // Unicode - assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)") + Seq(true, false).foreach { escape => + val conf = new SQLConf() + conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, escape.toString) + val parser = new CatalystSqlParser(conf) + + // tests that have same result whatever the conf is + // Single Strings. + assertEqual("\"hello\"", "hello", parser) + assertEqual("'hello'", "hello", parser) + + // Multi-Strings. + assertEqual("\"hello\" 'world'", "helloworld", parser) + assertEqual("'hello' \" \" 'world'", "hello world", parser) + + // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a + // regular '%'; to get the correct result you need to add another escaped '\'. + // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? + assertEqual("'pattern%'", "pattern%", parser) + assertEqual("'no-pattern\\%'", "no-pattern\\%", parser) + + // tests that have different result regarding the conf + if (escape) { + // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing fallbacks to + // Spark 1.6 behavior. + + // 'LIKE' string literals. + assertEqual("'pattern\\\\%'", "pattern\\\\%", parser) + assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser) + + // Escaped characters. + // Unescape string literal "'\\0'" for ASCII NUL (X'00') doesn't work + // when ESCAPED_STRING_LITERALS is enabled. + // It is parsed literally. + assertEqual("'\\0'", "\\0", parser) + + // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is enabled. + val e = intercept[ParseException](parser.parseExpression("'\''")) + assert(e.message.contains("extraneous input '''")) + + // The unescape special characters (e.g., "\\t") for 2.0+ don't work + // when ESCAPED_STRING_LITERALS is enabled. They are parsed literally. + assertEqual("'\\\"'", "\\\"", parser) // Double quote + assertEqual("'\\b'", "\\b", parser) // Backspace + assertEqual("'\\n'", "\\n", parser) // Newline + assertEqual("'\\r'", "\\r", parser) // Carriage return + assertEqual("'\\t'", "\\t", parser) // Tab character + + // The unescape Octals for 2.0+ don't work when ESCAPED_STRING_LITERALS is enabled. + // They are parsed literally. + assertEqual("'\\110\\145\\154\\154\\157\\041'", "\\110\\145\\154\\154\\157\\041", parser) + // The unescape Unicode for 2.0+ doesn't work when ESCAPED_STRING_LITERALS is enabled. + // They are parsed literally. + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", + "\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029", parser) + } else { + // Default behavior + + // 'LIKE' string literals. + assertEqual("'pattern\\\\%'", "pattern\\%", parser) + assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser) + + // Escaped characters. + // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html + assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00') + assertEqual("'\\''", "\'", parser) // Single quote + assertEqual("'\\\"'", "\"", parser) // Double quote + assertEqual("'\\b'", "\b", parser) // Backspace + assertEqual("'\\n'", "\n", parser) // Newline + assertEqual("'\\r'", "\r", parser) // Carriage return + assertEqual("'\\t'", "\t", parser) // Tab character + assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows) + + // Octals + assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser) + + // Unicode + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)", + parser) + } + + } } test("intervals") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 411777d6e85a2..950f152b94b4d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, UnresolvedInlineTable, UnresolvedTableValuedFunction} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedTableValuedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -176,14 +176,14 @@ class PlanParserSuite extends PlanTest { def insert( partition: Map[String, Option[String]], overwrite: Boolean = false, - ifNotExists: Boolean = false): LogicalPlan = - InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) + ifPartitionNotExists: Boolean = false): LogicalPlan = + InsertIntoTable(table("s"), partition, plan, overwrite, ifPartitionNotExists) // Single inserts assertEqual(s"insert overwrite table s $sql", insert(Map.empty, overwrite = true)) assertEqual(s"insert overwrite table s partition (e = 1) if not exists $sql", - insert(Map("e" -> Option("1")), overwrite = true, ifNotExists = true)) + insert(Map("e" -> Option("1")), overwrite = true, ifPartitionNotExists = true)) assertEqual(s"insert into s $sql", insert(Map.empty)) assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql", @@ -193,9 +193,9 @@ class PlanParserSuite extends PlanTest { val plan2 = table("t").where('x > 5).select(star()) assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", InsertIntoTable( - table("s"), Map.empty, plan.limit(1), false, ifNotExists = false).union( + table("s"), Map.empty, plan.limit(1), false, ifPartitionNotExists = false).union( InsertIntoTable( - table("u"), Map.empty, plan2, false, ifNotExists = false))) + table("u"), Map.empty, plan2, false, ifPartitionNotExists = false))) } test ("insert with if not exists") { @@ -223,6 +223,12 @@ class PlanParserSuite extends PlanTest { assertEqual(s"$sql grouping sets((a, b), (a), ())", GroupingSets(Seq(Seq('a, 'b), Seq('a), Seq()), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c")))) + + val m = intercept[ParseException] { + parsePlan("SELECT a, b, count(distinct a, distinct b) as c FROM d GROUP BY a, b") + }.getMessage + assert(m.contains("extraneous input 'b'")) + } test("limit") { @@ -496,46 +502,109 @@ class PlanParserSuite extends PlanTest { val m = intercept[ParseException] { parsePlan("SELECT /*+ HINT() */ * FROM t") }.getMessage - assert(m.contains("no viable alternative at input")) - - // Hive compatibility: No database. - val m2 = intercept[ParseException] { - parsePlan("SELECT /*+ MAPJOIN(default.t) */ * from default.t") - }.getMessage - assert(m2.contains("mismatched input '.' expecting {')', ','}")) + assert(m.contains("mismatched input")) // Disallow space as the delimiter. val m3 = intercept[ParseException] { parsePlan("SELECT /*+ INDEX(a b c) */ * from default.t") }.getMessage - assert(m3.contains("mismatched input 'b' expecting {')', ','}")) + assert(m3.contains("mismatched input 'b' expecting")) comparePlans( parsePlan("SELECT /*+ HINT */ * FROM t"), - Hint("HINT", Seq.empty, table("t").select(star()))) + UnresolvedHint("HINT", Seq.empty, table("t").select(star()))) comparePlans( parsePlan("SELECT /*+ BROADCASTJOIN(u) */ * FROM t"), - Hint("BROADCASTJOIN", Seq("u"), table("t").select(star()))) + UnresolvedHint("BROADCASTJOIN", Seq($"u"), table("t").select(star()))) comparePlans( parsePlan("SELECT /*+ MAPJOIN(u) */ * FROM t"), - Hint("MAPJOIN", Seq("u"), table("t").select(star()))) + UnresolvedHint("MAPJOIN", Seq($"u"), table("t").select(star()))) comparePlans( parsePlan("SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t"), - Hint("STREAMTABLE", Seq("a", "b", "c"), table("t").select(star()))) + UnresolvedHint("STREAMTABLE", Seq($"a", $"b", $"c"), table("t").select(star()))) comparePlans( parsePlan("SELECT /*+ INDEX(t, emp_job_ix) */ * FROM t"), - Hint("INDEX", Seq("t", "emp_job_ix"), table("t").select(star()))) + UnresolvedHint("INDEX", Seq($"t", $"emp_job_ix"), table("t").select(star()))) comparePlans( parsePlan("SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`"), - Hint("MAPJOIN", Seq("default.t"), table("default.t").select(star()))) + UnresolvedHint("MAPJOIN", Seq(UnresolvedAttribute.quoted("default.t")), + table("default.t").select(star()))) comparePlans( parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"), - Hint("MAPJOIN", Seq("t"), table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc)) + UnresolvedHint("MAPJOIN", Seq($"t"), + table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc)) + } + + test("SPARK-20854: select hint syntax with expressions") { + comparePlans( + parsePlan("SELECT /*+ HINT1(a, array(1, 2, 3)) */ * from t"), + UnresolvedHint("HINT1", Seq($"a", + UnresolvedFunction("array", Literal(1) :: Literal(2) :: Literal(3) :: Nil, false)), + table("t").select(star()) + ) + ) + + comparePlans( + parsePlan("SELECT /*+ HINT1(a, 5, 'a', b) */ * from t"), + UnresolvedHint("HINT1", Seq($"a", Literal(5), Literal("a"), $"b"), + table("t").select(star()) + ) + ) + + comparePlans( + parsePlan("SELECT /*+ HINT1('a', (b, c), (1, 2)) */ * from t"), + UnresolvedHint("HINT1", + Seq(Literal("a"), + CreateStruct($"b" :: $"c" :: Nil), + CreateStruct(Literal(1) :: Literal(2) :: Nil)), + table("t").select(star()) + ) + ) + } + + test("SPARK-20854: multiple hints") { + comparePlans( + parsePlan("SELECT /*+ HINT1(a, 1) hint2(b, 2) */ * from t"), + UnresolvedHint("HINT1", Seq($"a", Literal(1)), + UnresolvedHint("hint2", Seq($"b", Literal(2)), + table("t").select(star()) + ) + ) + ) + + comparePlans( + parsePlan("SELECT /*+ HINT1(a, 1),hint2(b, 2) */ * from t"), + UnresolvedHint("HINT1", Seq($"a", Literal(1)), + UnresolvedHint("hint2", Seq($"b", Literal(2)), + table("t").select(star()) + ) + ) + ) + + comparePlans( + parsePlan("SELECT /*+ HINT1(a, 1) */ /*+ hint2(b, 2) */ * from t"), + UnresolvedHint("HINT1", Seq($"a", Literal(1)), + UnresolvedHint("hint2", Seq($"b", Literal(2)), + table("t").select(star()) + ) + ) + ) + + comparePlans( + parsePlan("SELECT /*+ HINT1(a, 1), hint2(b, 2) */ /*+ hint3(c, 3) */ * from t"), + UnresolvedHint("HINT1", Seq($"a", Literal(1)), + UnresolvedHint("hint2", Seq($"b", Literal(2)), + UnresolvedHint("hint3", Seq($"c", Literal(3)), + table("t").select(star()) + ) + ) + ) + ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 467f76193cfc5..7c8ed78a49116 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Union} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, ResolvedHint, Union} import org.apache.spark.sql.catalyst.util._ /** @@ -66,4 +66,10 @@ class SameResultSuite extends SparkFunSuite { assertSameResult(Union(Seq(testRelation, testRelation2)), Union(Seq(testRelation2, testRelation))) } + + test("hint") { + val df1 = testRelation.join(ResolvedHint(testRelation)) + val df2 = testRelation.join(testRelation) + assertSameResult(df1, df2) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index b06871f96f0d8..2afea6dd3d37c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -37,19 +37,20 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { test("BroadcastHint estimation") { val filter = Filter(Literal(true), plan) - val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false, + val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4), rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute -> colStat))) - val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false) + val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4)) checkStats( filter, expectedStatsCboOn = filterStatsCboOn, expectedStatsCboOff = filterStatsCboOff) - val broadcastHint = BroadcastHint(filter) + val broadcastHint = ResolvedHint(filter, HintInfo(isBroadcastable = Option(true))) checkStats( broadcastHint, - expectedStatsCboOn = filterStatsCboOn.copy(isBroadcastable = true), - expectedStatsCboOff = filterStatsCboOff.copy(isBroadcastable = true)) + expectedStatsCboOn = filterStatsCboOn.copy(hints = HintInfo(isBroadcastable = Option(true))), + expectedStatsCboOff = filterStatsCboOff.copy(hints = HintInfo(isBroadcastable = Option(true))) + ) } test("limit estimation: limit < child's rowCount") { @@ -94,15 +95,13 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 40, rowCount = Some(10), attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))), - isBroadcastable = false) + AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4)))) val expectedCboStats = Statistics( sizeInBytes = 4, rowCount = Some(1), attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))), - isBroadcastable = false) + AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4)))) val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) checkStats( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index a28447840ae09..2fa53a6466ef2 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -150,7 +150,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt), + Seq(attrInt -> colStatInt.copy(distinctCount = 3)), expectedRowCount = 3) } @@ -158,7 +158,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt), + Seq(attrInt -> colStatInt.copy(distinctCount = 8)), expectedRowCount = 8) } @@ -174,7 +174,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(LessThan(attrInt, Literal(3)), Not(Literal(null, IntegerType)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt), + Seq(attrInt -> colStatInt.copy(distinctCount = 8)), expectedRowCount = 8) } @@ -205,7 +205,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint < 3") { validateEstimatedStats( Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4)), expectedRowCount = 3) } @@ -221,7 +221,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint <= 3") { validateEstimatedStats( Filter(LessThanOrEqual(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(3), nullCount = 0, avgLen = 4, maxLen = 4)), expectedRowCount = 3) } @@ -229,7 +229,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint > 6") { validateEstimatedStats( Filter(GreaterThan(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4)), expectedRowCount = 5) } @@ -245,7 +245,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint >= 6") { validateEstimatedStats( Filter(GreaterThanOrEqual(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + Seq(attrInt -> ColumnStat(distinctCount = 5, min = Some(6), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4)), expectedRowCount = 5) } @@ -279,7 +279,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), + Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(3), max = Some(6), nullCount = 0, avgLen = 4, maxLen = 4)), expectedRowCount = 4) } @@ -288,8 +288,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Or(EqualTo(attrInt, Literal(3)), EqualTo(attrInt, Literal(6))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> colStatInt.copy(distinctCount = 2)), expectedRowCount = 2) } @@ -297,7 +296,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt), + Seq(attrInt -> colStatInt.copy(distinctCount = 6)), expectedRowCount = 6) } @@ -305,7 +304,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(Or(LessThanOrEqual(attrInt, Literal(3)), GreaterThan(attrInt, Literal(6)))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> colStatInt), + Seq(attrInt -> colStatInt.copy(distinctCount = 5)), expectedRowCount = 5) } @@ -321,7 +320,8 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val condition = Not(Or(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), - Seq(attrInt -> colStatInt, attrString -> colStatString), + Seq(attrInt -> colStatInt.copy(distinctCount = 9), + attrString -> colStatString.copy(distinctCount = 9)), expectedRowCount = 9) } @@ -336,8 +336,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4)), + Seq(attrInt -> colStatInt.copy(distinctCount = 7)), expectedRowCount = 7) } @@ -380,7 +379,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { validateEstimatedStats( Filter(LessThan(attrDate, Literal(d20170103, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), - Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), + Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(dMin), max = Some(d20170103), nullCount = 0, avgLen = 4, maxLen = 4)), expectedRowCount = 3) } @@ -421,7 +420,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cdouble < 3.0") { validateEstimatedStats( Filter(LessThan(attrDouble, Literal(3.0)), childStatsTestPlan(Seq(attrDouble), 10L)), - Seq(attrDouble -> ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), + Seq(attrDouble -> ColumnStat(distinctCount = 3, min = Some(1.0), max = Some(3.0), nullCount = 0, avgLen = 8, maxLen = 8)), expectedRowCount = 3) } @@ -487,9 +486,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(EqualTo(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4)), expectedRowCount = 4) } @@ -498,9 +497,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(GreaterThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(10), + attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4)), expectedRowCount = 4) } @@ -509,9 +508,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(LessThan(attrInt, attrInt2), childStatsTestPlan(Seq(attrInt, attrInt2), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10), + Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - attrInt2 -> ColumnStat(distinctCount = 3, min = Some(7), max = Some(16), + attrInt2 -> ColumnStat(distinctCount = 4, min = Some(7), max = Some(16), nullCount = 0, avgLen = 4, maxLen = 4)), expectedRowCount = 4) } @@ -531,9 +530,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // partial overlap case validateEstimatedStats( Filter(LessThan(attrInt, attrInt4), childStatsTestPlan(Seq(attrInt, attrInt4), 10L)), - Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10), + Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4), - attrInt4 -> ColumnStat(distinctCount = 3, min = Some(1), max = Some(10), + attrInt4 -> ColumnStat(distinctCount = 4, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4)), expectedRowCount = 4) } @@ -565,6 +564,20 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 0) } + test("update ndv for columns based on overall selectivity") { + // filter condition: cint > 3 AND cint4 <= 6 + val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt4, Literal(6))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt, attrInt4, attrString), 10L)), + Seq( + attrInt -> ColumnStat(distinctCount = 5, min = Some(3), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attrInt4 -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(6), + nullCount = 0, avgLen = 4, maxLen = 4), + attrString -> colStatString.copy(distinctCount = 5)), + expectedRowCount = 5) + } + private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { StatsTestPlan( outputList = outList, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 37e3dfabd0b21..06ef7bcee0d84 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -54,13 +54,21 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]]) override def output: Seq[Attribute] = Nil } -case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable { +case class ExpressionInMap(map: Map[String, Expression]) extends Unevaluable { override def children: Seq[Expression] = map.values.toSeq override def nullable: Boolean = true override def dataType: NullType = NullType override lazy val resolved = true } +case class SeqTupleExpression(sons: Seq[(Expression, Expression)], + nonSons: Seq[(Expression, Expression)]) extends Unevaluable { + override def children: Seq[Expression] = sons.flatMap(t => Iterator(t._1, t._2)) + override def nullable: Boolean = true + override def dataType: NullType = NullType + override lazy val resolved = true +} + case class JsonTestTreeNode(arg: Any) extends LeafNode { override def output: Seq[Attribute] = Seq.empty[Attribute] } @@ -146,6 +154,17 @@ class TreeNodeSuite extends SparkFunSuite { assert(actual === Dummy(None)) } + test("mapChildren should only works on children") { + val children = Seq((Literal(1), Literal(2))) + val nonChildren = Seq((Literal(3), Literal(4))) + val before = SeqTupleExpression(children, nonChildren) + val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } + val expect = SeqTupleExpression(Seq((Literal(0), Literal(0))), nonChildren) + + val actual = before mapChildren toZero + assert(actual === expect) + } + test("preserves origin") { CurrentOrigin.setPosition(1, 1) val add = Add(Literal(1), Literal(1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 9799817494f15..c8cf16d937352 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -34,6 +34,22 @@ class DateTimeUtilsSuite extends SparkFunSuite { ((timestamp + tz.getOffset(timestamp)) / MILLIS_PER_DAY).toInt } + test("nanoseconds truncation") { + def checkStringToTimestamp(originalTime: String, expectedParsedTime: String) { + val parsedTimestampOp = DateTimeUtils.stringToTimestamp(UTF8String.fromString(originalTime)) + assert(parsedTimestampOp.isDefined, "timestamp with nanoseconds was not parsed correctly") + assert(DateTimeUtils.timestampToString(parsedTimestampOp.get) === expectedParsedTime) + } + + checkStringToTimestamp("2015-01-02 00:00:00.123456789", "2015-01-02 00:00:00.123456") + checkStringToTimestamp("2015-01-02 00:00:00.100000009", "2015-01-02 00:00:00.1") + checkStringToTimestamp("2015-01-02 00:00:00.000050000", "2015-01-02 00:00:00.00005") + checkStringToTimestamp("2015-01-02 00:00:00.12005", "2015-01-02 00:00:00.12005") + checkStringToTimestamp("2015-01-02 00:00:00.100", "2015-01-02 00:00:00.1") + checkStringToTimestamp("2015-01-02 00:00:00.000456789", "2015-01-02 00:00:00.000456") + checkStringToTimestamp("1950-01-02 00:00:00.000456789", "1950-01-02 00:00:00.000456") + } + test("timestamp and us") { val now = new Timestamp(System.currentTimeMillis()) now.setNanos(1000) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 714883a4099cf..144f3d688d402 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -32,6 +32,16 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { test("creating decimals") { checkDecimal(new Decimal(), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("0.09")), "0.09", 3, 2) + checkDecimal(Decimal(BigDecimal("0.9")), "0.9", 2, 1) + checkDecimal(Decimal(BigDecimal("0.90")), "0.90", 3, 2) + checkDecimal(Decimal(BigDecimal("0.0")), "0.0", 2, 1) + checkDecimal(Decimal(BigDecimal("0")), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("1.0")), "1.0", 2, 1) + checkDecimal(Decimal(BigDecimal("-0.09")), "-0.09", 3, 2) + checkDecimal(Decimal(BigDecimal("-0.9")), "-0.9", 2, 1) + checkDecimal(Decimal(BigDecimal("-0.90")), "-0.90", 3, 2) + checkDecimal(Decimal(BigDecimal("-1.0")), "-1.0", 2, 1) checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1) @@ -212,4 +222,10 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { } } } + + test("SPARK-20341: support BigInt's value does not fit in long value range") { + val bigInt = scala.math.BigInt("9223372036854775808") + val decimal = Decimal.apply(bigInt) + assert(decimal.toJavaBigDecimal.unscaledValue.toString === "9223372036854775808") + } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index b203f31a76f03..bee0434c57f00 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.0-csd-1-SNAPSHOT ../../pom.xml diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index eb97118872ea1..5a810cae1e184 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -66,7 +66,6 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.StructType$; import org.apache.spark.util.AccumulatorV2; -import org.apache.spark.util.LongAccumulator; /** * Base class for custom RecordReaders for Parquet that directly materialize to `T`. @@ -153,14 +152,16 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont } // For test purpose. - // If the predefined accumulator exists, the row group number to read will be updated - // to the accumulator. So we can check if the row groups are filtered or not in test case. + // If the last external accumulator is `NumRowGroupsAccumulator`, the row group number to read + // will be updated to the accumulator. So we can check if the row groups are filtered or not + // in test case. TaskContext taskContext = TaskContext$.MODULE$.get(); if (taskContext != null) { - Option> accu = taskContext.taskMetrics() - .lookForAccumulatorByName("numRowGroups"); - if (accu.isDefined()) { - ((LongAccumulator)accu.get()).add((long)blocks.size()); + Option> accu = taskContext.taskMetrics().externalAccums().lastOption(); + if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) { + @SuppressWarnings("unchecked") + AccumulatorV2 intAccum = (AccumulatorV2) accu.get(); + intAccum.add(blocks.size()); } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 354c878aca000..b105e60a2d34a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -180,7 +180,7 @@ public Object[] array() { @Override public boolean getBoolean(int ordinal) { - throw new UnsupportedOperationException(); + return data.getBoolean(offset + ordinal); } @Override @@ -188,7 +188,7 @@ public boolean getBoolean(int ordinal) { @Override public short getShort(int ordinal) { - throw new UnsupportedOperationException(); + return data.getShort(offset + ordinal); } @Override @@ -199,7 +199,7 @@ public short getShort(int ordinal) { @Override public float getFloat(int ordinal) { - throw new UnsupportedOperationException(); + return data.getFloat(offset + ordinal); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index a6ce4c2edc232..8b7b0e655b31d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -198,21 +198,25 @@ public boolean anyNull() { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { + if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getUTF8String(rowId); } @Override public byte[] getBinary(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getBinary(rowId); } @Override public CalendarInterval getInterval(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; final int months = columns[ordinal].getChildColumn(0).getInt(rowId); final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); return new CalendarInterval(months, microseconds); @@ -220,11 +224,13 @@ public CalendarInterval getInterval(int ordinal) { @Override public InternalRow getStruct(int ordinal, int numFields) { + if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getStruct(rowId); } @Override public ArrayData getArray(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getArray(rowId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index e988c0722bd72..cda7f2fe23815 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -436,28 +436,29 @@ public void loadBytes(ColumnVector.Array array) { // Split out the slow path. @Override protected void reserveInternal(int newCapacity) { + int oldCapacity = (nulls == 0L) ? 0 : capacity; if (this.resultArray != null) { this.lengthData = - Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4); + Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); this.offsetData = - Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4); + Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4); } else if (type instanceof ByteType || type instanceof BooleanType) { - this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity); + this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity); } else if (type instanceof ShortType) { - this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); + this.data = Platform.reallocateMemory(data, oldCapacity * 2, newCapacity * 2); } else if (type instanceof IntegerType || type instanceof FloatType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { - this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); + this.data = Platform.reallocateMemory(data, oldCapacity * 4, newCapacity * 4); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { - this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8); + this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); } else if (resultStruct != null) { // Nothing to store. } else { throw new RuntimeException("Unhandled " + type); } - this.nulls = Platform.reallocateMemory(nulls, elementsAppended, newCapacity); - Platform.setMemory(nulls + elementsAppended, (byte)0, newCapacity - elementsAppended); + this.nulls = Platform.reallocateMemory(nulls, oldCapacity, newCapacity); + Platform.setMemory(nulls + oldCapacity, (byte)0, newCapacity - oldCapacity); capacity = newCapacity; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 9b410bacff5df..94ed32294cfae 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -410,53 +410,53 @@ protected void reserveInternal(int newCapacity) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { - System.arraycopy(this.arrayLengths, 0, newLengths, 0, elementsAppended); - System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, elementsAppended); + System.arraycopy(this.arrayLengths, 0, newLengths, 0, capacity); + System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, capacity); } arrayLengths = newLengths; arrayOffsets = newOffsets; } else if (type instanceof BooleanType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, capacity); byteData = newData; } } else if (type instanceof ByteType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, capacity); byteData = newData; } } else if (type instanceof ShortType) { if (shortData == null || shortData.length < newCapacity) { short[] newData = new short[newCapacity]; - if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); + if (shortData != null) System.arraycopy(shortData, 0, newData, 0, capacity); shortData = newData; } } else if (type instanceof IntegerType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { if (intData == null || intData.length < newCapacity) { int[] newData = new int[newCapacity]; - if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); + if (intData != null) System.arraycopy(intData, 0, newData, 0, capacity); intData = newData; } } else if (type instanceof LongType || type instanceof TimestampType || DecimalType.is64BitDecimalType(type)) { if (longData == null || longData.length < newCapacity) { long[] newData = new long[newCapacity]; - if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); + if (longData != null) System.arraycopy(longData, 0, newData, 0, capacity); longData = newData; } } else if (type instanceof FloatType) { if (floatData == null || floatData.length < newCapacity) { float[] newData = new float[newCapacity]; - if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended); + if (floatData != null) System.arraycopy(floatData, 0, newData, 0, capacity); floatData = newData; } } else if (type instanceof DoubleType) { if (doubleData == null || doubleData.length < newCapacity) { double[] newData = new double[newCapacity]; - if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); + if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, capacity); doubleData = newData; } } else if (resultStruct != null) { @@ -466,7 +466,7 @@ protected void reserveInternal(int newCapacity) { } byte[] newNulls = new byte[newCapacity]; - if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, elementsAppended); + if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, capacity); nulls = newNulls; capacity = newCapacity; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java rename to sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java index 3e3997fa9bfec..d31790a285687 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.java +++ b/sql/core/src/main/java/org/apache/spark/sql/streaming/Trigger.java @@ -21,22 +21,18 @@ import scala.concurrent.duration.Duration; -import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; import org.apache.spark.sql.execution.streaming.OneTimeTrigger$; /** - * :: Experimental :: * Policy used to indicate how often results should be produced by a [[StreamingQuery]]. * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving public class Trigger { /** - * :: Experimental :: * A trigger policy that runs a query periodically based on an interval in processing time. * If `interval` is 0, the query will run as fast as possible. * @@ -47,7 +43,6 @@ public static Trigger ProcessingTime(long intervalMs) { } /** - * :: Experimental :: * (Java-friendly) * A trigger policy that runs a query periodically based on an interval in processing time. * If `interval` is 0, the query will run as fast as possible. @@ -64,7 +59,6 @@ public static Trigger ProcessingTime(long interval, TimeUnit timeUnit) { } /** - * :: Experimental :: * (Scala-friendly) * A trigger policy that runs a query periodically based on an interval in processing time. * If `duration` is 0, the query will run as fast as possible. @@ -80,7 +74,6 @@ public static Trigger ProcessingTime(Duration interval) { } /** - * :: Experimental :: * A trigger policy that runs a query periodically based on an interval in processing time. * If `interval` is effectively 0, the query will run as fast as possible. * diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 27d32b5dca431..0c5f3f22e31e8 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -5,3 +5,4 @@ org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat org.apache.spark.sql.execution.datasources.text.TextFileFormat org.apache.spark.sql.execution.streaming.ConsoleSinkProvider org.apache.spark.sql.execution.streaming.TextSocketSourceProvider +org.apache.spark.sql.execution.streaming.RateSourceProvider diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 052d85ad33bd6..1d88992c48562 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -244,13 +244,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * import com.google.common.collect.ImmutableMap; * * // Replaces all occurrences of 1.0 with 2.0 in column "height". - * df.replace("height", ImmutableMap.of(1.0, 2.0)); + * df.na.replace("height", ImmutableMap.of(1.0, 2.0)); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". - * df.replace("name", ImmutableMap.of("UNKNOWN", "unnamed")); + * df.na.replace("name", ImmutableMap.of("UNKNOWN", "unnamed")); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. - * df.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); + * df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed")); * }}} * * @param col name of the column to apply the value replacement @@ -271,10 +271,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * import com.google.common.collect.ImmutableMap; * * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". - * df.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0)); + * df.na.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0)); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". - * df.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); + * df.na.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed")); * }}} * * @param cols list of columns to apply the value replacement @@ -295,13 +295,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height". - * df.replace("height", Map(1.0 -> 2.0)) + * df.na.replace("height", Map(1.0 -> 2.0)); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name". - * df.replace("name", Map("UNKNOWN" -> "unnamed") + * df.na.replace("name", Map("UNKNOWN" -> "unnamed")); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns. - * df.replace("*", Map("UNKNOWN" -> "unnamed") + * df.na.replace("*", Map("UNKNOWN" -> "unnamed")); * }}} * * @param col name of the column to apply the value replacement @@ -324,10 +324,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * {{{ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight". - * df.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0)); + * df.na.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0)); * * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname". - * df.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"); + * df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed")); * }}} * * @param cols list of columns to apply the value replacement diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index c1b32917415ae..628a82fd23c13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -283,7 +283,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * Loads JSON files and returns the results as a `DataFrame`. * * JSON Lines (newline-delimited JSON) is supported by - * default. For JSON (one record per file), set the `wholeFile` option to true. + * default. For JSON (one record per file), set the `multiLine` option to true. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -323,7 +323,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines, + *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • * * @@ -525,7 +525,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { *
  • `columnNameOfCorruptRecord` (default is the value specified in * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
  • - *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines.
  • + *
  • `multiLine` (default `false`): parse one record, which may span multiple lines.
  • * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 1732a8e08b73f..0259fffeab2db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation, SaveIntoDataSourceCommand} @@ -286,7 +286,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partition = Map.empty[String, Option[String]], query = df.logicalPlan, overwrite = mode == SaveMode.Overwrite, - ifNotExists = false) + ifPartitionNotExists = false) } } @@ -372,8 +372,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { // Get all input data source or hive relations of the query. val srcRelations = df.logicalPlan.collect { case LogicalRelation(src: BaseRelation, _, _) => src - case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) => - relation.tableMeta.identifier + case relation: HiveTableRelation => relation.tableMeta.identifier } val tableRelation = df.sparkSession.table(tableIdentWithDB).queryExecution.analyzed @@ -383,8 +382,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { throw new AnalysisException( s"Cannot overwrite table $tableName that is also being read from") // check hive table relation when overwrite mode - case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) - && srcRelations.contains(relation.tableMeta.identifier) => + case relation: HiveTableRelation + if srcRelations.contains(relation.tableMeta.identifier) => throw new AnalysisException( s"Cannot overwrite table $tableName that is also being read from") case _ => // OK diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 520663f624408..a775fb8ed4ed3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import java.io.CharArrayWriter import java.sql.{Date, Timestamp} -import java.util.TimeZone import scala.collection.JavaConverters._ import scala.language.implicitConversions @@ -36,7 +35,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.CatalogRelation +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -132,7 +131,7 @@ private[sql] object Dataset { * * people.filter("age > 30") * .join(department, people("deptId") === department("id")) - * .groupBy(department("name"), "gender") + * .groupBy(department("name"), people("gender")) * .agg(avg(people("salary")), max(people("age"))) * }}} * @@ -142,9 +141,9 @@ private[sql] object Dataset { * Dataset people = spark.read().parquet("..."); * Dataset department = spark.read().parquet("..."); * - * people.filter("age".gt(30)) - * .join(department, people.col("deptId").equalTo(department("id"))) - * .groupBy(department.col("name"), "gender") + * people.filter(people.col("age").gt(30)) + * .join(department, people.col("deptId").equalTo(department.col("id"))) + * .groupBy(department.col("name"), people.col("gender")) * .agg(avg(people.col("salary")), max(people.col("age"))); * }}} * @@ -247,7 +246,8 @@ class Dataset[T] private[sql]( val hasMoreData = takeResult.length > numRows val data = takeResult.take(numRows) - lazy val timeZone = TimeZone.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone) + lazy val timeZone = + DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone) // For array values, replace Seq and Array with square brackets // For cells that are beyond `truncate` characters, replace it with the @@ -484,7 +484,6 @@ class Dataset[T] private[sql]( * @group streaming * @since 2.0.0 */ - @Experimental @InterfaceStability.Evolving def isStreaming: Boolean = logicalPlan.isStreaming @@ -545,7 +544,6 @@ class Dataset[T] private[sql]( } /** - * :: Experimental :: * Defines an event time watermark for this [[Dataset]]. A watermark tracks a point in time * before which we assume no more late data is going to arrive. * @@ -569,7 +567,6 @@ class Dataset[T] private[sql]( * @group streaming * @since 2.1.0 */ - @Experimental @InterfaceStability.Evolving // We only accept an existing column name, not a derived column here as a watermark that is // defined on a derived column cannot referenced elsewhere in the plan. @@ -579,7 +576,8 @@ class Dataset[T] private[sql]( .getOrElse(throw new AnalysisException(s"Unable to parse time delay '$delayThreshold'")) require(parsedDelay.milliseconds >= 0 && parsedDelay.months >= 0, s"delay threshold ($delayThreshold) should not be negative.") - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan) + EliminateEventTimeWatermark( + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) } /** @@ -903,7 +901,7 @@ class Dataset[T] private[sql]( * @param condition Join expression. * @param joinType Type of join to perform. Default `inner`. Must be one of: * `inner`, `cross`, `outer`, `full`, `full_outer`, `left`, `left_outer`, - * `right`, `right_outer`, `left_semi`, `left_anti`. + * `right`, `right_outer`. * * @group typedrel * @since 1.6.0 @@ -920,6 +918,10 @@ class Dataset[T] private[sql]( JoinType(joinType), Some(condition.expr))).analyzed.asInstanceOf[Join] + if (joined.joinType == LeftSemi || joined.joinType == LeftAnti) { + throw new AnalysisException("Invalid join type in joinWith: " + joined.joinType.sql) + } + // For both join side, combine all outputs into a single column and alias it with "_1" or "_2", // to match the schema for the encoder of the join result. // Note that we do this before joining them, to enable the join operator to return null for one @@ -1026,7 +1028,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def sort(sortCol: String, sortCols: String*): Dataset[T] = { - sort((sortCol +: sortCols).map(apply) : _*) + sort((sortCol +: sortCols).map(Column(_)) : _*) } /** @@ -1073,6 +1075,22 @@ class Dataset[T] private[sql]( */ def apply(colName: String): Column = col(colName) + /** + * Specifies some hint on the current Dataset. As an example, the following code specifies + * that one of the plan can be broadcasted: + * + * {{{ + * df1.join(df2.hint("broadcast")) + * }}} + * + * @group basic + * @since 2.2.0 + */ + @scala.annotation.varargs + def hint(name: String, parameters: Any*): Dataset[T] = withTypedPlan { + UnresolvedHint(name, parameters, logicalPlan) + } + /** * Selects column based on the column name and return it as a [[Column]]. * @@ -1613,10 +1631,11 @@ class Dataset[T] private[sql]( /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. - * This is equivalent to `UNION ALL` in SQL. * - * To do a SQL-style set union (that does deduplication of elements), use this function followed - * by a [[distinct]]. + * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does + * deduplication of elements), use this function followed by a [[distinct]]. + * + * Also as standard in SQL, this function resolves columns by position (not by name). * * @group typedrel * @since 2.0.0 @@ -1626,10 +1645,11 @@ class Dataset[T] private[sql]( /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. - * This is equivalent to `UNION ALL` in SQL. * - * To do a SQL-style set union (that does deduplication of elements), use this function followed - * by a [[distinct]]. + * This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does + * deduplication of elements), use this function followed by a [[distinct]]. + * + * Also as standard in SQL, this function resolves columns by position (not by name). * * @group typedrel * @since 2.0.0 @@ -1726,15 +1746,23 @@ class Dataset[T] private[sql]( // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its // constituent partitions each time a split is materialized which could result in // overlapping splits. To prevent this, we explicitly sort each input partition to make the - // ordering deterministic. - // MapType cannot be sorted. - val sorted = Sort(logicalPlan.output.filterNot(_.dataType.isInstanceOf[MapType]) - .map(SortOrder(_, Ascending)), global = false, logicalPlan) + // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out + // from the sort order. + val sortOrder = logicalPlan.output + .filter(attr => RowOrdering.isOrderable(attr.dataType)) + .map(SortOrder(_, Ascending)) + val plan = if (sortOrder.nonEmpty) { + Sort(sortOrder, global = false, logicalPlan) + } else { + // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism + cache() + logicalPlan + } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => new Dataset[T]( - sparkSession, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder) + sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder) }.toArray } @@ -2632,6 +2660,22 @@ class Dataset[T] private[sql]( createTempViewCommand(viewName, replace = false, global = true) } + /** + * Creates or replaces a global temporary view using the given name. The lifetime of this + * temporary view is tied to this Spark application. + * + * Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application, + * i.e. it will be automatically dropped when the application terminates. It's tied to a system + * preserved database `_global_temp`, and we must use the qualified name to refer a global temp + * view, e.g. `SELECT * FROM _global_temp.view1`. + * + * @group basic + * @since 2.2.0 + */ + def createOrReplaceGlobalTempView(viewName: String): Unit = withPlan { + createTempViewCommand(viewName, replace = true, global = true) + } + private def createTempViewCommand( viewName: String, replace: Boolean, @@ -2670,13 +2714,11 @@ class Dataset[T] private[sql]( } /** - * :: Experimental :: * Interface for saving the content of the streaming Dataset out into external storage. * * @group basic * @since 2.0.0 */ - @Experimental @InterfaceStability.Evolving def writeStream: DataStreamWriter[T] = { if (!isStreaming) { @@ -2735,7 +2777,7 @@ class Dataset[T] private[sql]( fsBasedRelation.inputFiles case fr: FileRelation => fr.inputFiles - case r: CatalogRelation if DDLUtils.isHiveTable(r.tableMeta) => + case r: HiveTableRelation => r.tableMeta.storage.locationUri.map(_.toString).toArray }.flatten files.toSet.toArray @@ -2778,7 +2820,7 @@ class Dataset[T] private[sql]( * Wrap a Dataset action to track all Spark jobs in the body so that we can connect them with * an execution. */ - private[sql] def withNewExecutionId[U](body: => U): U = { + private def withNewExecutionId[U](body: => U): U = { SQLExecution.withNewExecutionId(sparkSession, queryExecution)(body) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala index 372ec262f5764..86e02e98c01f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability /** - * :: Experimental :: * A class to consume data generated by a `StreamingQuery`. Typically this is used to send the * generated data to external systems. Each partition will use a new deserialized instance, so you * usually should do all the initialization (e.g. opening a connection or initiating a transaction) @@ -66,7 +65,6 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability} * }}} * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving abstract class ForeachWriter[T] extends Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cc2983987eb90..7fde6e9469e5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -505,7 +505,6 @@ class SQLContext private[sql](val sparkSession: SparkSession) /** - * :: Experimental :: * Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`. * {{{ * sparkSession.readStream.parquet("/path/to/directory/of/parquet/files") @@ -514,7 +513,6 @@ class SQLContext private[sql](val sparkSession: SparkSession) * * @since 2.0.0 */ - @Experimental @InterfaceStability.Evolving def readStream: DataStreamReader = sparkSession.readStream diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 375df64d39734..17671ea8685b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -111,93 +111,60 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits { /** * @since 1.6.1 - * @deprecated use [[newIntSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder() /** * @since 1.6.1 - * @deprecated use [[newLongSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder() /** * @since 1.6.1 - * @deprecated use [[newDoubleSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder() /** * @since 1.6.1 - * @deprecated use [[newFloatSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder() /** * @since 1.6.1 - * @deprecated use [[newByteSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder() /** * @since 1.6.1 - * @deprecated use [[newShortSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder() /** * @since 1.6.1 - * @deprecated use [[newBooleanSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder() /** * @since 1.6.1 - * @deprecated use [[newStringSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder() /** * @since 1.6.1 - * @deprecated use [[newProductSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ - implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder() + def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder() /** @since 2.2.0 */ - implicit def newIntSequenceEncoder[T <: Seq[Int] : TypeTag]: Encoder[T] = - ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newLongSequenceEncoder[T <: Seq[Long] : TypeTag]: Encoder[T] = - ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newDoubleSequenceEncoder[T <: Seq[Double] : TypeTag]: Encoder[T] = - ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newFloatSequenceEncoder[T <: Seq[Float] : TypeTag]: Encoder[T] = - ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newByteSequenceEncoder[T <: Seq[Byte] : TypeTag]: Encoder[T] = - ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newShortSequenceEncoder[T <: Seq[Short] : TypeTag]: Encoder[T] = - ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newBooleanSequenceEncoder[T <: Seq[Boolean] : TypeTag]: Encoder[T] = - ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newStringSequenceEncoder[T <: Seq[String] : TypeTag]: Encoder[T] = - ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newProductSequenceEncoder[T <: Seq[Product] : TypeTag]: Encoder[T] = - ExpressionEncoder() + implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder() // Arrays diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 95f3463dfe62b..62b277d3bff50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -32,13 +32,14 @@ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.HadoopFileSelector import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.ui.SQLListener -import org.apache.spark.sql.internal.{BaseSessionStateBuilder, CatalogImpl, SessionState, SessionStateBuilder, SharedState} +import org.apache.spark.sql.internal._ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ @@ -77,11 +78,12 @@ import org.apache.spark.util.Utils class SparkSession private( @transient val sparkContext: SparkContext, @transient private val existingSharedState: Option[SharedState], - @transient private val parentSessionState: Option[SessionState]) + @transient private val parentSessionState: Option[SessionState], + @transient private[sql] val extensions: SparkSessionExtensions) extends Serializable with Closeable with Logging { self => private[sql] def this(sc: SparkContext) { - this(sc, None, None) + this(sc, None, None, new SparkSessionExtensions) } sparkContext.assertNotStopped() @@ -111,6 +113,12 @@ class SparkSession private( existingSharedState.getOrElse(new SharedState(sparkContext)) } + /** + * Initial options for session. This options are applied once when sessionState is created. + */ + @transient + private[sql] val initialSessionOptions = new scala.collection.mutable.HashMap[String, String] + /** * State isolated across sessions, including SQL configurations, temporary tables, registered * functions, and everything else that accepts a [[org.apache.spark.sql.internal.SQLConf]]. @@ -126,9 +134,11 @@ class SparkSession private( parentSessionState .map(_.clone(this)) .getOrElse { - SparkSession.instantiateSessionState( + val state = SparkSession.instantiateSessionState( SparkSession.sessionStateClassName(sparkContext.conf), self) + initialSessionOptions.foreach { case (k, v) => state.conf.setConfString(k, v) } + state } } @@ -219,7 +229,7 @@ class SparkSession private( * @since 2.0.0 */ def newSession(): SparkSession = { - new SparkSession(sparkContext, Some(sharedState), parentSessionState = None) + new SparkSession(sparkContext, Some(sharedState), parentSessionState = None, extensions) } /** @@ -235,7 +245,7 @@ class SparkSession private( * implementation is Hive, this will initialize the metastore, which may take some time. */ private[sql] def cloneSession(): SparkSession = { - val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState)) + val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState), extensions) result.sessionState // force copy of SessionState result } @@ -608,6 +618,18 @@ class SparkSession private( Dataset.ofRows(self, sessionState.catalog.lookupRelation(tableIdent)) } + def setTableNamePreprocessor(newTableNamePreprocessor: (String) => String): Unit = { + sharedState.externalCatalog.setTableNamePreprocessor(newTableNamePreprocessor) + } + + def setHadoopFileSelector(hadoopFileSelector: HadoopFileSelector): Unit = { + sharedState.externalCatalog.setHadoopFileSelector(hadoopFileSelector) + } + + def unsetHadoopFileSelector(): Unit = { + sharedState.externalCatalog.unsetHadoopFileSelector() + } + /* ----------------- * | Everything else | * ----------------- */ @@ -635,7 +657,6 @@ class SparkSession private( def read: DataFrameReader = new DataFrameReader(self) /** - * :: Experimental :: * Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`. * {{{ * sparkSession.readStream.parquet("/path/to/directory/of/parquet/files") @@ -644,7 +665,6 @@ class SparkSession private( * * @since 2.0.0 */ - @Experimental @InterfaceStability.Evolving def readStream: DataStreamReader = new DataStreamReader(self) @@ -754,6 +774,8 @@ object SparkSession { private[this] val options = new scala.collection.mutable.HashMap[String, String] + private[this] val extensions = new SparkSessionExtensions + private[this] var userSuppliedContext: Option[SparkContext] = None private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized { @@ -847,6 +869,17 @@ object SparkSession { } } + /** + * Inject extensions into the [[SparkSession]]. This allows a user to add Analyzer rules, + * Optimizer rules, Planning Strategies or a customized parser. + * + * @since 2.2.0 + */ + def withExtensions(f: SparkSessionExtensions => Unit): Builder = { + f(extensions) + this + } + /** * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new * one based on the options set in this builder. @@ -903,8 +936,27 @@ object SparkSession { } sc } - session = new SparkSession(sparkContext) - options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } + + // Initialize extensions if the user has defined a configurator class. + val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS) + if (extensionConfOption.isDefined) { + val extensionConfClassName = extensionConfOption.get + try { + val extensionConfClass = Utils.classForName(extensionConfClassName) + val extensionConf = extensionConfClass.newInstance() + .asInstanceOf[SparkSessionExtensions => Unit] + extensionConf(extensions) + } catch { + // Ignore the error if we cannot find the class or when the class has the wrong type. + case e @ (_: ClassCastException | + _: ClassNotFoundException | + _: NoClassDefFoundError) => + logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e) + } + } + + session = new SparkSession(sparkContext, None, None, extensions) + options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) } defaultSession.set(session) // Register a successfully instantiated context to the singleton. This should be at the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala new file mode 100644 index 0000000000000..f99c108161f94 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.mutable + +import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * :: Experimental :: + * Holder for injection points to the [[SparkSession]]. We make NO guarantee about the stability + * regarding binary compatibility and source compatibility of methods here. + * + * This current provides the following extension points: + * - Analyzer Rules. + * - Check Analysis Rules + * - Optimizer Rules. + * - Planning Strategies. + * - Customized Parser. + * - (External) Catalog listeners. + * + * The extensions can be used by calling withExtension on the [[SparkSession.Builder]], for + * example: + * {{{ + * SparkSession.builder() + * .master("...") + * .conf("...", true) + * .withExtensions { extensions => + * extensions.injectResolutionRule { session => + * ... + * } + * extensions.injectParser { (session, parser) => + * ... + * } + * } + * .getOrCreate() + * }}} + * + * Note that none of the injected builders should assume that the [[SparkSession]] is fully + * initialized and should not touch the session's internals (e.g. the SessionState). + */ +@DeveloperApi +@Experimental +@InterfaceStability.Unstable +class SparkSessionExtensions { + type RuleBuilder = SparkSession => Rule[LogicalPlan] + type CheckRuleBuilder = SparkSession => LogicalPlan => Unit + type StrategyBuilder = SparkSession => Strategy + type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface + + private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + + /** + * Build the analyzer resolution `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + resolutionRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an analyzer resolution `Rule` builder into the [[SparkSession]]. These analyzer + * rules will be executed as part of the resolution phase of analysis. + */ + def injectResolutionRule(builder: RuleBuilder): Unit = { + resolutionRuleBuilders += builder + } + + private[this] val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + + /** + * Build the analyzer post-hoc resolution `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildPostHocResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + postHocResolutionRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an analyzer `Rule` builder into the [[SparkSession]]. These analyzer + * rules will be executed after resolution. + */ + def injectPostHocResolutionRule(builder: RuleBuilder): Unit = { + postHocResolutionRuleBuilders += builder + } + + private[this] val checkRuleBuilders = mutable.Buffer.empty[CheckRuleBuilder] + + /** + * Build the check analysis `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildCheckRules(session: SparkSession): Seq[LogicalPlan => Unit] = { + checkRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an check analysis `Rule` builder into the [[SparkSession]]. The injected rules will + * be executed after the analysis phase. A check analysis rule is used to detect problems with a + * LogicalPlan and should throw an exception when a problem is found. + */ + def injectCheckRule(builder: CheckRuleBuilder): Unit = { + checkRuleBuilders += builder + } + + private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder] + + private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + optimizerRules.map(_.apply(session)) + } + + /** + * Inject an optimizer `Rule` builder into the [[SparkSession]]. The injected rules will be + * executed during the operator optimization batch. An optimizer rule is used to improve the + * quality of an analyzed logical plan; these rules should never modify the result of the + * LogicalPlan. + */ + def injectOptimizerRule(builder: RuleBuilder): Unit = { + optimizerRules += builder + } + + private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] + + private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = { + plannerStrategyBuilders.map(_.apply(session)) + } + + /** + * Inject a planner `Strategy` builder into the [[SparkSession]]. The injected strategy will + * be used to convert a `LogicalPlan` into a executable + * [[org.apache.spark.sql.execution.SparkPlan]]. + */ + def injectPlannerStrategy(builder: StrategyBuilder): Unit = { + plannerStrategyBuilders += builder + } + + private[this] val parserBuilders = mutable.Buffer.empty[ParserBuilder] + + private[sql] def buildParser( + session: SparkSession, + initial: ParserInterface): ParserInterface = { + parserBuilders.foldLeft(initial) { (parser, builder) => + builder(session, parser) + } + } + + /** + * Inject a custom parser into the [[SparkSession]]. Note that the builder is passed a session + * and an initial parser. The latter allows for a user to create a partial parser and to delegate + * to the underlying parser for completeness. If a user injects more parsers, then the parsers + * are stacked on top of each other. + */ + def injectParser(builder: ParserBuilder): Unit = { + parserBuilders += builder + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index a57673334c10b..6accf1f75064c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -70,15 +70,31 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @param name the name of the UDAF. * @param udaf the UDAF needs to be registered. * @return the registered UDAF. + * + * @since 1.5.0 */ - def register( - name: String, - udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = { + def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = { def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf) functionRegistry.registerFunction(name, builder) udaf } + /** + * Register a user-defined function (UDF), for a UDF that's already defined using the DataFrame + * API (i.e. of type UserDefinedFunction). + * + * @param name the name of the UDF. + * @param udf the UDF needs to be registered. + * @return the registered UDF. + * + * @since 2.2.0 + */ + def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = { + def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr + functionRegistry.registerFunction(name, builder) + udf + } + // scalastyle:off line.size.limit /* register 0-22 were generated by this script diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 866fa98533218..6fb41b6425c4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -67,7 +67,7 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { * Shorthand for calling redactString() without specifying redacting rules */ private def redact(text: String): String = { - Utils.redact(SparkSession.getActiveSession.get.sparkContext.conf, text) + Utils.redact(SparkSession.getActiveSession.map(_.sparkContext.conf).orNull, text) } } @@ -519,8 +519,8 @@ case class FileSourceScanExec( relation, output.map(QueryPlan.normalizeExprId(_, output)), requiredSchema, - partitionFilters.map(QueryPlan.normalizeExprId(_, output)), - dataFilters.map(QueryPlan.normalizeExprId(_, output)), + QueryPlan.normalizePredicates(partitionFilters, output), + QueryPlan.normalizePredicates(dataFilters, output), None) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala index 458ac4ba3637c..01c9c65e5399d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala @@ -31,16 +31,16 @@ import org.apache.spark.storage.BlockManager import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} /** - * An append-only array for [[UnsafeRow]]s that spills content to disk when there a predefined - * threshold of rows is reached. + * An append-only array for [[UnsafeRow]]s that strictly keeps content in an in-memory array + * until [[numRowsInMemoryBufferThreshold]] is reached post which it will switch to a mode which + * would flush to disk after [[numRowsSpillThreshold]] is met (or before if there is + * excessive memory consumption). Setting these threshold involves following trade-offs: * - * Setting spill threshold faces following trade-off: - * - * - If the spill threshold is too high, the in-memory array may occupy more memory than is - * available, resulting in OOM. - * - If the spill threshold is too low, we spill frequently and incur unnecessary disk writes. - * This may lead to a performance regression compared to the normal case of using an - * [[ArrayBuffer]] or [[Array]]. + * - If [[numRowsInMemoryBufferThreshold]] is too high, the in-memory array may occupy more memory + * than is available, resulting in OOM. + * - If [[numRowsSpillThreshold]] is too low, data will be spilled frequently and lead to + * excessive disk writes. This may lead to a performance regression compared to the normal case + * of using an [[ArrayBuffer]] or [[Array]]. */ private[sql] class ExternalAppendOnlyUnsafeRowArray( taskMemoryManager: TaskMemoryManager, @@ -49,9 +49,10 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( taskContext: TaskContext, initialSize: Int, pageSizeBytes: Long, + numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) extends Logging { - def this(numRowsSpillThreshold: Int) { + def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) { this( TaskContext.get().taskMemoryManager(), SparkEnv.get.blockManager, @@ -59,11 +60,12 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( TaskContext.get(), 1024, SparkEnv.get.memoryManager.pageSizeBytes, + numRowsInMemoryBufferThreshold, numRowsSpillThreshold) } private val initialSizeOfInMemoryBuffer = - Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsSpillThreshold) + Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsInMemoryBufferThreshold) private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) { new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer) @@ -102,11 +104,11 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( } def add(unsafeRow: UnsafeRow): Unit = { - if (numRows < numRowsSpillThreshold) { + if (numRows < numRowsInMemoryBufferThreshold) { inMemoryBuffer += unsafeRow.copy() } else { if (spillableArray == null) { - logInfo(s"Reached spill threshold of $numRowsSpillThreshold rows, switching to " + + logInfo(s"Reached spill threshold of $numRowsInMemoryBufferThreshold rows, switching to " + s"${classOf[UnsafeExternalSorter].getName}") // We will not sort the rows, so prefixComparator and recordComparator are null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index f87d05884b276..c35e5638e9273 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} private[execution] sealed case class LazyIterator(func: () => TraversableOnce[InternalRow]) extends Iterator[InternalRow] { - lazy val results = func().toIterator + lazy val results: Iterator[InternalRow] = func().toIterator override def hasNext: Boolean = results.hasNext override def next(): InternalRow = results.next() } @@ -50,7 +50,7 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the - * given `generator` is empty. `outer` has no effect when `join` is false. + * given `generator` is empty. * @param generatorOutput the qualified output attributes of the generator of this node, which * constructed in analysis phase, and we can not change it, as the * parent node bound with it already. @@ -78,15 +78,15 @@ case class GenerateExec( override def outputPartitioning: Partitioning = child.outputPartitioning - val boundGenerator = BindReferences.bindReference(generator, child.output) + lazy val boundGenerator: Generator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[InternalRow] = { // boundGenerator.terminate() should be triggered after all of the rows in the partition - val rows = if (join) { - child.execute().mapPartitionsInternal { iter => - val generatorNullRow = new GenericInternalRow(generator.elementSchema.length) + val numOutputRows = longMetric("numOutputRows") + child.execute().mapPartitionsWithIndexInternal { (index, iter) => + val generatorNullRow = new GenericInternalRow(generator.elementSchema.length) + val rows = if (join) { val joinedRow = new JoinedRow - iter.flatMap { row => // we should always set the left (child output) joinedRow.withLeft(row) @@ -101,18 +101,21 @@ case class GenerateExec( // keep it the same as Hive does joinedRow.withRight(row) } + } else { + iter.flatMap { row => + val outputRows = boundGenerator.eval(row) + if (outer && outputRows.isEmpty) { + Seq(generatorNullRow) + } else { + outputRows + } + } ++ LazyIterator(boundGenerator.terminate) } - } else { - child.execute().mapPartitionsInternal { iter => - iter.flatMap(boundGenerator.eval) ++ LazyIterator(boundGenerator.terminate) - } - } - val numOutputRows = longMetric("numOutputRows") - rows.mapPartitionsWithIndexInternal { (index, iter) => + // Convert the rows to unsafe rows. val proj = UnsafeProjection.create(output, output) proj.initialize(index) - iter.map { r => + rows.map { r => numOutputRows += 1 proj(r) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 19c68c13262a5..514ad7018d8c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -28,12 +28,12 @@ import org.apache.spark.sql.execution.metric.SQLMetrics */ case class LocalTableScanExec( output: Seq[Attribute], - rows: Seq[InternalRow]) extends LeafExecNode { + @transient rows: Seq[InternalRow]) extends LeafExecNode { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - private lazy val unsafeRows: Array[InternalRow] = { + @transient private lazy val unsafeRows: Array[InternalRow] = { if (rows.isEmpty) { Array.empty } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index 3c046ce494285..d59b3c6f0caf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{HiveTableRelation, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -101,7 +101,7 @@ case class OptimizeMetadataOnlyQuery( val partitionData = fsRelation.location.listFiles(Nil, Nil) LocalRelation(partAttrs, partitionData.map(_.values)) - case relation: CatalogRelation => + case relation: HiveTableRelation => val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) val caseInsensitiveProperties = CaseInsensitiveMap(relation.tableMeta.storage.properties) @@ -137,7 +137,7 @@ case class OptimizeMetadataOnlyQuery( val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l) Some(AttributeSet(partAttrs), l) - case relation: CatalogRelation if relation.tableMeta.partitionColumnNames.nonEmpty => + case relation: HiveTableRelation if relation.tableMeta.partitionColumnNames.nonEmpty => val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) Some(AttributeSet(partAttrs), relation) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 8e8210e334a1d..2e05e5d65923c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} -import java.util.TimeZone import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} @@ -187,7 +186,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d)) case (t: Timestamp, TimestampType) => DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(t), - TimeZone.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)) + DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)) case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal) case (other, tpe) if primitiveTypes.contains(tpe) => other.toString diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index cadab37a449aa..c4ed96640eb19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -22,6 +22,9 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext +import org.codehaus.commons.compiler.CompileException +import org.codehaus.janino.JaninoRuntimeException + import org.apache.spark.{broadcast, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec @@ -353,9 +356,27 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination) } + private def genInterpretedPredicate( + expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = { + val str = expression.toString + val logMessage = if (str.length > 256) { + str.substring(0, 256 - 3) + "..." + } else { + str + } + logWarning(s"Codegen disabled for this expression:\n $logMessage") + InterpretedPredicate.create(expression, inputSchema) + } + protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = { - GeneratePredicate.generate(expression, inputSchema) + try { + GeneratePredicate.generate(expression, inputSchema) + } catch { + case e @ (_: JaninoRuntimeException | _: CompileException) + if sqlContext == null || sqlContext.conf.wholeStageFallback => + genInterpretedPredicate(expression, inputSchema) + } } protected def newOrdering( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 6566502bd8a8a..4e718d609c921 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -36,7 +36,7 @@ class SparkPlanner( experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( FileSourceStrategy :: - DataSourceStrategy :: + DataSourceStrategy(conf) :: SpecialLimits :: Aggregation :: JoinSelection :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 20dacf88504f1..c2c52894860b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -52,7 +52,7 @@ class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser { /** * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. */ -class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { +class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { import org.apache.spark.sql.catalyst.parser.ParserUtils._ /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ca2f6dd7a84b2..843ce63161220 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -114,7 +114,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose output should be small enough to be used in broadcast join. */ private def canBroadcast(plan: LogicalPlan): Boolean = { - plan.stats(conf).isBroadcastable || + plan.stats(conf).hints.isBroadcastable.getOrElse(false) || (plan.stats(conf).sizeInBytes >= 0 && plan.stats(conf).sizeInBytes <= conf.autoBroadcastJoinThreshold) } @@ -383,8 +383,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil case logical.FlatMapGroupsWithState( - f, key, value, grouping, data, output, _, _, _, _, child) => - execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil + f, key, value, grouping, data, output, _, _, _, timeout, child) => + execution.MapGroupsExec( + f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, @@ -432,7 +433,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil - case BroadcastHint(child) => planLater(child) :: Nil + case h: ResolvedHint => planLater(h.child) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index c1e1a631c677e..974315db584da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -489,13 +489,13 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { * Inserts an InputAdapter on top of those that do not support codegen. */ private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match { - case j @ SortMergeJoinExec(_, _, _, _, left, right) if j.supportCodegen => - // The children of SortMergeJoin should do codegen separately. - j.copy(left = InputAdapter(insertWholeStageCodegen(left)), - right = InputAdapter(insertWholeStageCodegen(right))) case p if !supportCodegen(p) => // collapse them recursively InputAdapter(insertWholeStageCodegen(p)) + case j @ SortMergeJoinExec(_, _, _, _, left, right) => + // The children of SortMergeJoin should do codegen separately. + j.copy(left = InputAdapter(insertWholeStageCodegen(left)), + right = InputAdapter(insertWholeStageCodegen(right))) case p => p.withNewChildren(p.children.map(insertInputAdapter)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 3a7fcf1fa9d89..6e47f9d611199 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.{BaseOrdering, GenerateOrdering} import org.apache.spark.sql.execution.UnsafeKVExternalSorter +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.KVIterator @@ -39,7 +40,8 @@ class ObjectAggregationIterator( newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, originalInputAttributes: Seq[Attribute], inputRows: Iterator[InternalRow], - fallbackCountThreshold: Int) + fallbackCountThreshold: Int, + numOutputRows: SQLMetric) extends AggregationIterator( groupingExpressions, originalInputAttributes, @@ -83,7 +85,9 @@ class ObjectAggregationIterator( override final def next(): UnsafeRow = { val entry = aggBufferIterator.next() - generateOutput(entry.groupingKey, entry.aggregationBuffer) + val res = generateOutput(entry.groupingKey, entry.aggregationBuffer) + numOutputRows += 1 + res } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 3fcb7ec9a6411..b53521b1b6ba2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -117,7 +117,8 @@ case class ObjectHashAggregateExec( newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), child.output, iter, - fallbackCountThreshold) + fallbackCountThreshold, + numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { numOutputRows += 1 Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 44278e37c5276..bd7a5c5d914c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -21,7 +21,7 @@ import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import org.apache.spark.{InterruptibleIterator, SparkException, TaskContext} -import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} +import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} @@ -331,29 +331,32 @@ case class SampleExec( case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) extends LeafExecNode with CodegenSupport { - def start: Long = range.start - def step: Long = range.step - def numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism) - def numElements: BigInt = range.numElements + val start: Long = range.start + val end: Long = range.end + val step: Long = range.step + val numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism) + val numElements: BigInt = range.numElements override val output: Seq[Attribute] = range.output override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numGeneratedRows" -> SQLMetrics.createMetric(sparkContext, "number of generated rows")) + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override lazy val canonicalized: SparkPlan = { RangeExec(range.canonicalized.asInstanceOf[org.apache.spark.sql.catalyst.plans.logical.Range]) } override def inputRDDs(): Seq[RDD[InternalRow]] = { - sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) - .map(i => InternalRow(i)) :: Nil + val rdd = if (start == end || (start < end ^ 0 < step)) { + new EmptyRDD[InternalRow](sqlContext.sparkContext) + } else { + sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i)) + } + rdd :: Nil } protected override def doProduce(ctx: CodegenContext): String = { val numOutput = metricTerm(ctx, "numOutputRows") - val numGenerated = metricTerm(ctx, "numGeneratedRows") val initTerm = ctx.freshName("initRange") ctx.addMutableState("boolean", initTerm, s"$initTerm = false;") @@ -463,9 +466,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | $number = $batchEnd; | } | - | if ($taskContext.isInterrupted()) { - | throw new TaskKilledException(); - | } + | $taskContext.killTaskIfInterrupted(); | | long $nextBatchTodo; | if ($numElementsTodo > ${batchSize}L) { @@ -540,7 +541,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) } } - override def simpleString: String = range.simpleString + override def simpleString: String = s"Range ($start, $end, step=$step, splits=$numSlices)" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 214e8d309de11..7063b08f7c644 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -42,7 +42,9 @@ case class InMemoryTableScanExec( override def output: Seq[Attribute] = attributes private def updateAttribute(expr: Expression): Expression = { - val attrMap = AttributeMap(relation.child.output.zip(output)) + // attributes can be pruned so using relation's output. + // E.g., relation.output is [id, item] but this scan's output can be [item] only. + val attrMap = AttributeMap(relation.child.output.zip(relation.output)) expr.transform { case attr: Attribute => attrMap.getOrElse(attr, attr) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index d2ea0cdf61aa6..bf7c22761dc04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.net.URI + import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} @@ -45,10 +47,10 @@ case class AnalyzeTableCommand( } val newTotalSize = AnalyzeTableCommand.calculateTotalSize(sessionState, tableMeta) - val oldTotalSize = tableMeta.stats.map(_.sizeInBytes.toLong).getOrElse(0L) + val oldTotalSize = tableMeta.stats.map(_.sizeInBytes.toLong).getOrElse(-1L) val oldRowCount = tableMeta.stats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L) var newStats: Option[CatalogStatistics] = None - if (newTotalSize > 0 && newTotalSize != oldTotalSize) { + if (newTotalSize >= 0 && newTotalSize != oldTotalSize) { newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize)) } // We only set rowCount when noscan is false, because otherwise: @@ -81,6 +83,21 @@ case class AnalyzeTableCommand( object AnalyzeTableCommand extends Logging { def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): Long = { + if (catalogTable.partitionColumnNames.isEmpty) { + calculateLocationSize(sessionState, catalogTable.identifier, catalogTable.storage.locationUri) + } else { + // Calculate table size as a sum of the visible partitions. See SPARK-21079 + val partitions = sessionState.catalog.listPartitions(catalogTable.identifier) + partitions.map(p => + calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) + ).sum + } + } + + private def calculateLocationSize( + sessionState: SessionState, + tableId: TableIdentifier, + locationUri: Option[URI]): Long = { // This method is mainly based on // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) // in Hive 0.13 (except that we do not use fs.getContentSummary). @@ -91,13 +108,13 @@ object AnalyzeTableCommand extends Logging { // countFileSize to count the table size. val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") - def calculateTableSize(fs: FileSystem, path: Path): Long = { + def calculateLocationSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) val size = if (fileStatus.isDirectory) { fs.listStatus(path) .map { status => if (!status.getPath.getName.startsWith(stagingDir)) { - calculateTableSize(fs, status.getPath) + calculateLocationSize(fs, status.getPath) } else { 0L } @@ -109,16 +126,16 @@ object AnalyzeTableCommand extends Logging { size } - catalogTable.storage.locationUri.map { p => + locationUri.map { p => val path = new Path(p) try { val fs = path.getFileSystem(sessionState.newHadoopConf()) - calculateTableSize(fs, path) + calculateLocationSize(fs, path) } catch { case NonFatal(e) => logWarning( - s"Failed to get the size of table ${catalogTable.identifier.table} in the " + - s"database ${catalogTable.identifier.database} because of ${e.toString}", e) + s"Failed to get the size of table ${tableId.table} in the " + + s"database ${tableId.database} because of ${e.toString}", e) 0L } }.getOrElse(0L) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 55540563ef911..b1eaecb98defe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -21,7 +21,6 @@ import java.util.Locale import scala.collection.{GenMap, GenSeq} import scala.collection.parallel.ForkJoinTaskSupport -import scala.concurrent.forkjoin.ForkJoinPool import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration @@ -36,7 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.types._ -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.{SerializableConfiguration, ThreadUtils} // Note: The definition of these commands are based on the ones described in // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -582,8 +581,15 @@ case class AlterTableRecoverPartitionsCommand( val threshold = spark.conf.get("spark.rdd.parallelListingThreshold", "10").toInt val hadoopConf = spark.sparkContext.hadoopConfiguration val pathFilter = getPathFilter(hadoopConf) - val partitionSpecsAndLocs = scanPartitions(spark, fs, pathFilter, root, Map(), - table.partitionColumnNames, threshold, spark.sessionState.conf.resolver) + + val evalPool = ThreadUtils.newForkJoinPool("AlterTableRecoverPartitionsCommand", 8) + val partitionSpecsAndLocs: Seq[(TablePartitionSpec, Path)] = + try { + scanPartitions(spark, fs, pathFilter, root, Map(), table.partitionColumnNames, threshold, + spark.sessionState.conf.resolver, new ForkJoinTaskSupport(evalPool)).seq + } finally { + evalPool.shutdown() + } val total = partitionSpecsAndLocs.length logInfo(s"Found $total partitions in $root") @@ -604,8 +610,6 @@ case class AlterTableRecoverPartitionsCommand( Seq.empty[Row] } - @transient private lazy val evalTaskSupport = new ForkJoinTaskSupport(new ForkJoinPool(8)) - private def scanPartitions( spark: SparkSession, fs: FileSystem, @@ -614,7 +618,8 @@ case class AlterTableRecoverPartitionsCommand( spec: TablePartitionSpec, partitionNames: Seq[String], threshold: Int, - resolver: Resolver): GenSeq[(TablePartitionSpec, Path)] = { + resolver: Resolver, + evalTaskSupport: ForkJoinTaskSupport): GenSeq[(TablePartitionSpec, Path)] = { if (partitionNames.isEmpty) { return Seq(spec -> path) } @@ -638,7 +643,7 @@ case class AlterTableRecoverPartitionsCommand( val value = ExternalCatalogUtils.unescapePathName(ps(1)) if (resolver(columnName, partitionNames.head)) { scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(partitionNames.head -> value), - partitionNames.drop(1), threshold, resolver) + partitionNames.drop(1), threshold, resolver, evalTaskSupport) } else { logWarning( s"expected partition column ${partitionNames.head}, but got ${ps(0)}, ignoring it") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index ebf03e1bf8869..63486382a405c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -522,15 +522,15 @@ case class DescribeTableCommand( throw new AnalysisException( s"DESC PARTITION is not allowed on a temporary view: ${table.identifier}") } - describeSchema(catalog.lookupRelation(table).schema, result) + describeSchema(catalog.lookupRelation(table).schema, result, header = false) } else { val metadata = catalog.getTableMetadata(table) if (metadata.schema.isEmpty) { // In older version(prior to 2.1) of Spark, the table schema can be empty and should be // inferred at runtime. We should still support it. - describeSchema(sparkSession.table(metadata.identifier).schema, result) + describeSchema(sparkSession.table(metadata.identifier).schema, result, header = false) } else { - describeSchema(metadata.schema, result) + describeSchema(metadata.schema, result, header = false) } describePartitionInfo(metadata, result) @@ -550,7 +550,7 @@ case class DescribeTableCommand( private def describePartitionInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { if (table.partitionColumnNames.nonEmpty) { append(buffer, "# Partition Information", "", "") - describeSchema(table.partitionSchema, buffer) + describeSchema(table.partitionSchema, buffer, header = true) } } @@ -601,8 +601,13 @@ case class DescribeTableCommand( table.storage.toLinkedHashMap.foreach(s => append(buffer, s._1, s._2, "")) } - private def describeSchema(schema: StructType, buffer: ArrayBuffer[Row]): Unit = { - append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) + private def describeSchema( + schema: StructType, + buffer: ArrayBuffer[Row], + header: Boolean): Unit = { + if (header) { + append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) + } schema.foreach { column => append(buffer, column.name, column.dataType.simpleString, column.getComment().orNull) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 00f0acab21aa2..3518ee581c5fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -159,7 +159,9 @@ case class CreateViewCommand( checkCyclicViewReference(analyzedPlan, Seq(viewIdent), viewIdent) // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` - catalog.alterTable(prepareTable(sparkSession, analyzedPlan)) + // Nothing we need to retain from the old view, so just drop and create a new one + catalog.dropTable(viewIdent, ignoreIfNotExists = false, purge = false) + catalog.createTable(prepareTable(sparkSession, analyzedPlan), ignoreIfExists = false) } else { // Handles `CREATE VIEW v0 AS SELECT ...`. Throws exception when the target view already // exists. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index f3b209deaae5c..a13bb2476ceab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} import scala.util.{Failure, Success, Try} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil @@ -39,6 +40,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{CalendarIntervalType, StructType} +import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.Utils /** @@ -122,7 +124,7 @@ case class DataSource( val hdfsPath = new Path(path) val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualified) + SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) }.toArray new InMemoryFileIndex(sparkSession, globbedPaths, options, None, fileStatusCache) } @@ -181,6 +183,11 @@ case class DataSource( throw new AnalysisException( s"Unable to infer schema for $format. It must be specified manually.") } + + SchemaUtils.checkColumnNameDuplication( + (dataSchema ++ partitionSchema).map(_.name), "in the data schema and the partition schema", + sparkSession.sessionState.conf.caseSensitiveAnalysis) + (dataSchema, partitionSchema) } @@ -339,22 +346,8 @@ case class DataSource( case (format: FileFormat, _) => val allPaths = caseInsensitiveOptions.get("path") ++ paths val hadoopConf = sparkSession.sessionState.newHadoopConf() - val globbedPaths = allPaths.flatMap { path => - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - val globPath = SparkHadoopUtil.get.globPathIfNecessary(qualified) - - if (globPath.isEmpty) { - throw new AnalysisException(s"Path does not exist: $qualified") - } - // Sufficient to check head of the globPath seq for non-glob scenario - // Don't need to check once again if files exist in streaming mode - if (checkFilesExist && !fs.exists(globPath.head)) { - throw new AnalysisException(s"Path does not exist: ${globPath.head}") - } - globPath - }.toArray + val globbedPaths = allPaths.flatMap( + DataSource.checkAndGlobPathIfNecessary(hadoopConf, _, checkFilesExist)).toArray val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format, fileStatusCache) @@ -408,16 +401,6 @@ case class DataSource( val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive) - // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does - // not need to have the query as child, to avoid to analyze an optimized query, - // because InsertIntoHadoopFsRelationCommand will be optimized first. - val partitionAttributes = partitionColumns.map { name => - val plan = data.logicalPlan - plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse { - throw new AnalysisException( - s"Unable to resolve $name given [${plan.output.map(_.name).mkString(", ")}]") - }.asInstanceOf[Attribute] - } val fileIndex = catalogTable.map(_.identifier).map { tableIdent => sparkSession.table(tableIdent).queryExecution.analyzed.collect { case LogicalRelation(t: HadoopFsRelation, _, _) => t.location @@ -430,7 +413,8 @@ case class DataSource( InsertIntoHadoopFsRelationCommand( outputPath = outputPath, staticPartitions = Map.empty, - partitionColumns = partitionAttributes, + ifPartitionNotExists = false, + partitionColumns = partitionColumns, bucketSpec = bucketSpec, fileFormat = format, options = options, @@ -481,7 +465,7 @@ case class DataSource( } } -object DataSource { +object DataSource extends Logging { /** A map to maintain backward compatibility in case we move data sources around. */ private val backwardCompatibilityMap: Map[String, String] = { @@ -570,10 +554,19 @@ object DataSource { // there is exactly one registered alias head.getClass case sources => - // There are multiple registered aliases for the input - sys.error(s"Multiple sources found for $provider1 " + - s"(${sources.map(_.getClass.getName).mkString(", ")}), " + - "please specify the fully qualified class name.") + // There are multiple registered aliases for the input. If there is single datasource + // that has "org.apache.spark" package in the prefix, we use it considering it is an + // internal datasource within Spark. + val sourceNames = sources.map(_.getClass.getName) + val internalSources = sources.filter(_.getClass.getName.startsWith("org.apache.spark")) + if (internalSources.size == 1) { + logWarning(s"Multiple sources found for $provider1 (${sourceNames.mkString(", ")}), " + + s"defaulting to the internal datasource (${internalSources.head.getClass.getName}).") + internalSources.head.getClass + } else { + throw new AnalysisException(s"Multiple sources found for $provider1 " + + s"(${sourceNames.mkString(", ")}), please specify the fully qualified class name.") + } } } catch { case e: ServiceConfigurationError if e.getCause.isInstanceOf[NoClassDefFoundError] => @@ -600,4 +593,28 @@ object DataSource { CatalogStorageFormat.empty.copy( locationUri = path.map(CatalogUtils.stringToURI), properties = optionsWithoutPath) } + + /** + * If `path` is a file pattern, return all the files that match it. Otherwise, return itself. + * If `checkFilesExist` is `true`, also check the file existence. + */ + private def checkAndGlobPathIfNecessary( + hadoopConf: Configuration, + path: String, + checkFilesExist: Boolean): Seq[Path] = { + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(hadoopConf) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) + + if (globPath.isEmpty) { + throw new AnalysisException(s"Path does not exist: $qualified") + } + // Sufficient to check head of the globPath seq for non-glob scenario + // Don't need to check once again if files exist in streaming mode + if (checkFilesExist && !fs.exists(globPath.head)) { + throw new AnalysisException(s"Path does not exist: ${globPath.head}") + } + globPath + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2d83d512e702d..04ee081a0f9ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -24,10 +24,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogUtils} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation @@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.UTF8String * Note that, this rule must be run after `PreprocessTableCreation` and * `PreprocessTableInsertion`. */ -case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] { +case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { def resolver: Resolver = conf.resolver @@ -98,11 +98,11 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] { val potentialSpecs = staticPartitions.filter { case (partKey, partValue) => resolver(field.name, partKey) } - if (potentialSpecs.size == 0) { + if (potentialSpecs.isEmpty) { None } else if (potentialSpecs.size == 1) { val partValue = potentialSpecs.head._2 - Some(Alias(Cast(Literal(partValue), field.dataType), field.name)()) + Some(Alias(cast(Literal(partValue), field.dataType), field.name)()) } else { throw new AnalysisException( s"Partition column ${field.name} have multiple values specified, " + @@ -142,8 +142,8 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] { parts, query, overwrite, false) if parts.isEmpty => InsertIntoDataSourceCommand(l, query, overwrite) - case InsertIntoTable( - l @ LogicalRelation(t: HadoopFsRelation, _, table), parts, query, overwrite, false) => + case i @ InsertIntoTable( + l @ LogicalRelation(t: HadoopFsRelation, _, table), parts, query, overwrite, _) => // If the InsertIntoTable command is for a partitioned HadoopFsRelation and // the user has specified static partitions, we add a Project operator on top of the query // to include those constant column values in the query result. @@ -188,14 +188,13 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] { "Cannot overwrite a path that is also being read from.") } - val partitionSchema = actualQuery.resolve( - t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver) val staticPartitions = parts.filter(_._2.nonEmpty).map { case (k, v) => k -> v.get } InsertIntoHadoopFsRelationCommand( outputPath, staticPartitions, - partitionSchema, + i.ifPartitionNotExists, + partitionColumns = t.partitionSchema.map(_.name), t.bucketSpec, t.fileFormat, t.options, @@ -208,15 +207,16 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] { /** - * Replaces [[CatalogRelation]] with data source table if its table provider is not hive. + * Replaces [[UnresolvedCatalogRelation]] with concrete relation logical plans. + * + * TODO: we should remove the special handling for hive tables after completely making hive as a + * data source. */ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] { - private def readDataSourceTable(r: CatalogRelation): LogicalPlan = { - val table = r.tableMeta + private def readDataSourceTable(table: CatalogTable): LogicalPlan = { val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table) - val cache = sparkSession.sessionState.catalog.tableRelationCache - - val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() { + val catalog = sparkSession.sessionState.catalog + catalog.getCachedPlan(qualifiedTableName, new Callable[LogicalPlan]() { override def call(): LogicalPlan = { val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_)) val dataSource = @@ -233,24 +233,30 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), table) } - }).asInstanceOf[LogicalRelation] + }) + } - if (r.output.isEmpty) { - // It's possible that the table schema is empty and need to be inferred at runtime. For this - // case, we don't need to change the output of the cached plan. - plan - } else { - plan.copy(output = r.output) - } + private def readHiveTable(table: CatalogTable): LogicalPlan = { + HiveTableRelation( + table, + // Hive table columns are always nullable. + table.dataSchema.asNullable.toAttributes, + table.partitionSchema.asNullable.toAttributes) } override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i @ InsertIntoTable(r: CatalogRelation, _, _, _, _) - if DDLUtils.isDatasourceTable(r.tableMeta) => - i.copy(table = readDataSourceTable(r)) + case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _) + if DDLUtils.isDatasourceTable(tableMeta) => + i.copy(table = readDataSourceTable(tableMeta)) - case r: CatalogRelation if DDLUtils.isDatasourceTable(r.tableMeta) => - readDataSourceTable(r) + case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _) => + i.copy(table = readHiveTable(tableMeta)) + + case UnresolvedCatalogRelation(tableMeta) if DDLUtils.isDatasourceTable(tableMeta) => + readDataSourceTable(tableMeta) + + case UnresolvedCatalogRelation(tableMeta) => + readHiveTable(tableMeta) } } @@ -258,7 +264,9 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] /** * A Strategy for planning scans over data sources defined using the sources API. */ -object DataSourceStrategy extends Strategy with Logging { +case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with CastSupport { + import DataSourceStrategy._ + def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) => pruneFilterProjectRaw( @@ -298,7 +306,7 @@ object DataSourceStrategy extends Strategy with Logging { // Restriction: Bucket pruning works iff the bucketing column has one and only one column. def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) - mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null) + mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null) val bucketIdGeneration = UnsafeProjection.create( HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, bucketColumn :: Nil) @@ -436,7 +444,9 @@ object DataSourceStrategy extends Strategy with Logging { private[this] def toCatalystRDD(relation: LogicalRelation, rdd: RDD[Row]): RDD[InternalRow] = { toCatalystRDD(relation, relation.output, rdd) } +} +object DataSourceStrategy { /** * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. * @@ -527,8 +537,8 @@ object DataSourceStrategy extends Strategy with Logging { * all [[Filter]]s that are completely filtered at the DataSource. */ protected[sql] def selectFilters( - relation: BaseRelation, - predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = { + relation: BaseRelation, + predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = { // For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are // called `predicate`s, while all data source filters of type `sources.Filter` are simply called diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 4ec09bff429c5..2c31d2a84c258 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -101,7 +101,7 @@ object FileFormatWriter extends Logging { committer: FileCommitProtocol, outputSpec: OutputSpec, hadoopConf: Configuration, - partitionColumns: Seq[Attribute], + partitionColumnNames: Seq[String], bucketSpec: Option[BucketSpec], refreshFunction: (Seq[TablePartitionSpec]) => Unit, options: Map[String, String]): Unit = { @@ -111,9 +111,18 @@ object FileFormatWriter extends Logging { job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) - val allColumns = queryExecution.logical.output + val allColumns = queryExecution.executedPlan.output + // Get the actual partition columns as attributes after matching them by name with + // the given columns names. + val partitionColumns = partitionColumnNames.map { col => + val nameEquality = sparkSession.sessionState.conf.resolver + allColumns.find(f => nameEquality(f.name, col)).getOrElse { + throw new RuntimeException( + s"Partition column $col not found in schema ${queryExecution.executedPlan.schema}") + } + } val partitionSet = AttributeSet(partitionColumns) - val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains) + val dataColumns = allColumns.filterNot(partitionSet.contains) val bucketIdExpression = bucketSpec.map { spec => val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 9897ab73b0da8..91e31650617ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.execution.streaming.FileStreamSink import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -36,20 +37,28 @@ import org.apache.spark.util.SerializableConfiguration * A [[FileIndex]] that generates the list of files to process by recursively listing all the * files present in `paths`. * - * @param rootPaths the list of root table paths to scan + * @param rootPathsSpecified the list of root table paths to scan (some of which might be + * filtered out later) * @param parameters as set of options to control discovery * @param partitionSchema an optional partition schema that will be use to provide types for the * discovered partitions */ class InMemoryFileIndex( sparkSession: SparkSession, - override val rootPaths: Seq[Path], + rootPathsSpecified: Seq[Path], parameters: Map[String, String], partitionSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends PartitioningAwareFileIndex( sparkSession, parameters, partitionSchema, fileStatusCache) { + // Filter out streaming metadata dirs or files such as "/.../_spark_metadata" (the metadata dir) + // or "/.../_spark_metadata/0" (a file in the metadata dir). `rootPathsSpecified` might contain + // such streaming metadata dir or files, e.g. when after globbing "basePath/*" where "basePath" + // is the output of a streaming query. + override val rootPaths = + rootPathsSpecified.filterNot(FileStreamSink.ancestorIsMetadataDirectory(_, hadoopConf)) + @volatile private var cachedLeafFiles: mutable.LinkedHashMap[Path, FileStatus] = _ @volatile private var cachedLeafDirToChildrenFiles: Map[Path, Array[FileStatus]] = _ @volatile private var cachedPartitionSpec: PartitionSpec = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 19b51d4d9530a..ab35fdcbc1f25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -37,11 +37,14 @@ import org.apache.spark.sql.execution.command._ * overwrites: when the spec is empty, all partitions are overwritten. * When it covers a prefix of the partition keys, only partitions matching * the prefix are overwritten. + * @param ifPartitionNotExists If true, only write if the partition does not exist. + * Only valid for static partitions. */ case class InsertIntoHadoopFsRelationCommand( outputPath: Path, staticPartitions: TablePartitionSpec, - partitionColumns: Seq[Attribute], + ifPartitionNotExists: Boolean, + partitionColumns: Seq[String], bucketSpec: Option[BucketSpec], fileFormat: FileFormat, options: Map[String, String], @@ -61,8 +64,8 @@ case class InsertIntoHadoopFsRelationCommand( val duplicateColumns = query.schema.fieldNames.groupBy(identity).collect { case (x, ys) if ys.length > 1 => "\"" + x + "\"" }.mkString(", ") - throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + - s"cannot save to file.") + throw new AnalysisException(s"Duplicate column(s): $duplicateColumns found, " + + "cannot save to file.") } val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options) @@ -76,11 +79,12 @@ case class InsertIntoHadoopFsRelationCommand( var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty + var matchingPartitions: Seq[CatalogTablePartition] = Seq.empty // When partitions are tracked by the catalog, compute all custom partition locations that // may be relevant to the insertion job. if (partitionsTrackedByCatalog) { - val matchingPartitions = sparkSession.sessionState.catalog.listPartitions( + matchingPartitions = sparkSession.sessionState.catalog.listPartitions( catalogTable.get.identifier, Some(staticPartitions)) initialMatchingPartitions = matchingPartitions.map(_.spec) customPartitionLocations = getCustomPartitionLocations( @@ -101,8 +105,12 @@ case class InsertIntoHadoopFsRelationCommand( case (SaveMode.ErrorIfExists, true) => throw new AnalysisException(s"path $qualifiedOutputPath already exists.") case (SaveMode.Overwrite, true) => - deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer) - true + if (ifPartitionNotExists && matchingPartitions.nonEmpty) { + false + } else { + deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer) + true + } case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => true case (SaveMode.Ignore, exists) => @@ -142,7 +150,7 @@ case class InsertIntoHadoopFsRelationCommand( outputSpec = FileFormatWriter.OutputSpec( qualifiedOutputPath.toString, customPartitionLocations), hadoopConf = hadoopConf, - partitionColumns = partitionColumns, + partitionColumnNames = partitionColumns, bucketSpec = bucketSpec, refreshFunction = refreshPartitionsCallback, options = options) @@ -168,10 +176,10 @@ case class InsertIntoHadoopFsRelationCommand( customPartitionLocations: Map[TablePartitionSpec, String], committer: FileCommitProtocol): Unit = { val staticPartitionPrefix = if (staticPartitions.nonEmpty) { - "/" + partitionColumns.flatMap { p => - staticPartitions.get(p.name) match { + "/" + partitionColumns.flatMap { col => + staticPartitions.get(col) match { case Some(value) => - Some(escapePathName(p.name) + "=" + escapePathName(value)) + Some(escapePathName(col) + "=" + escapePathName(value)) case None => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index ffd7f6c750f85..6b6f6388d54e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -177,7 +177,7 @@ abstract class PartitioningAwareFileIndex( }) val selected = partitions.filter { - case PartitionPath(values, _) => boundPredicate(values) + case PartitionPath(values, _) => boundPredicate.eval(values) } logInfo { val total = partitions.length diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index c3583209efc56..6f7438192dfe2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -94,7 +94,7 @@ object PartitioningUtils { typeInference: Boolean, basePaths: Set[Path], timeZoneId: String): PartitionSpec = { - parsePartitions(paths, typeInference, basePaths, TimeZone.getTimeZone(timeZoneId)) + parsePartitions(paths, typeInference, basePaths, DateTimeUtils.getTimeZone(timeZoneId)) } private[datasources] def parsePartitions( @@ -138,7 +138,7 @@ object PartitioningUtils { "root directory of the table. If there are multiple root directories, " + "please load them separately and then union them.") - val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues) + val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues, timeZone) // Creates the StructType which represents the partition columns. val fields = { @@ -243,7 +243,7 @@ object PartitioningUtils { if (equalSignIndex == -1) { None } else { - val columnName = columnSpec.take(equalSignIndex) + val columnName = unescapePathName(columnSpec.take(equalSignIndex)) assert(columnName.nonEmpty, s"Empty partition column name in '$columnSpec'") val rawColumnValue = columnSpec.drop(equalSignIndex + 1) @@ -322,7 +322,8 @@ object PartitioningUtils { * }}} */ def resolvePartitions( - pathsWithPartitionValues: Seq[(Path, PartitionValues)]): Seq[PartitionValues] = { + pathsWithPartitionValues: Seq[(Path, PartitionValues)], + timeZone: TimeZone): Seq[PartitionValues] = { if (pathsWithPartitionValues.isEmpty) { Seq.empty } else { @@ -337,7 +338,7 @@ object PartitioningUtils { val values = pathsWithPartitionValues.map(_._2) val columnCount = values.head.columnNames.size val resolvedValues = (0 until columnCount).map { i => - resolveTypeConflicts(values.map(_.literals(i))) + resolveTypeConflicts(values.map(_.literals(i)), timeZone) } // Fills resolved literals back to each partition @@ -474,7 +475,7 @@ object PartitioningUtils { * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower" * types. */ - private def resolveTypeConflicts(literals: Seq[Literal]): Seq[Literal] = { + private def resolveTypeConflicts(literals: Seq[Literal], timeZone: TimeZone): Seq[Literal] = { val desiredType = { val topType = literals.map(_.dataType).maxBy(upCastingOrder.indexOf(_)) // Falls back to string if all values of this column are null or empty string @@ -482,7 +483,7 @@ object PartitioningUtils { } literals.map { case l @ Literal(_, dataType) => - Literal.create(Cast(l, desiredType).eval(), desiredType) + Literal.create(Cast(l, desiredType, Some(timeZone.getID)).eval(), desiredType) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 905b8683e10bd..f5df1848a38c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.catalyst.catalog.CatalogStatistics import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} @@ -59,8 +60,11 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) val prunedFsRelation = fsRelation.copy(location = prunedFileIndex)(sparkSession) - val prunedLogicalRelation = logicalRelation.copy(relation = prunedFsRelation) - + // Change table stats based on the sizeInBytes of pruned files + val withStats = logicalRelation.catalogTable.map(_.copy( + stats = Some(CatalogStatistics(sizeInBytes = BigInt(prunedFileIndex.sizeInBytes))))) + val prunedLogicalRelation = logicalRelation.copy( + relation = prunedFsRelation, catalogTable = withStats) // Keep partition-pruning predicates so that they are visible in physical planning val filterExpression = filters.reduceLeft(And) val filter = Filter(filterExpression, prunedLogicalRelation) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 83bdf6fe224be..2de58384f9834 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -111,8 +111,8 @@ abstract class CSVDataSource extends Serializable { object CSVDataSource { def apply(options: CSVOptions): CSVDataSource = { - if (options.wholeFile) { - WholeFileCSVDataSource + if (options.multiLine) { + MultiLineCSVDataSource } else { TextInputCSVDataSource } @@ -196,7 +196,7 @@ object TextInputCSVDataSource extends CSVDataSource { } } -object WholeFileCSVDataSource extends CSVDataSource { +object MultiLineCSVDataSource extends CSVDataSource { override val isSplitable: Boolean = false override def readFile( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 62e4c6e4b4ea0..a13a5a34b4a84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -117,7 +117,7 @@ class CSVOptions( name.map(CompressionCodecs.getCodecClassName) } - val timeZone: TimeZone = TimeZone.getTimeZone( + val timeZone: TimeZone = DateTimeUtils.getTimeZone( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. @@ -128,7 +128,7 @@ class CSVOptions( FastDateFormat.getInstance( parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, Locale.US) - val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) + val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) val maxColumns = getInt("maxColumns", 20480) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 591096d5efd22..96a8a51da18e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -97,10 +97,13 @@ class JDBCOptions( val lowerBound = parameters.get(JDBC_LOWER_BOUND).map(_.toLong) // the upper bound of the partition column val upperBound = parameters.get(JDBC_UPPER_BOUND).map(_.toLong) - require(partitionColumn.isEmpty || - (lowerBound.isDefined && upperBound.isDefined && numPartitions.isDefined), - s"If '$JDBC_PARTITION_COLUMN' is specified then '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND'," + - s" and '$JDBC_NUM_PARTITIONS' are required.") + // numPartitions is also used for data source writing + require((partitionColumn.isEmpty && lowerBound.isEmpty && upperBound.isEmpty) || + (partitionColumn.isDefined && lowerBound.isDefined && upperBound.isDefined && + numPartitions.isDefined), + s"When reading JDBC data sources, users need to specify all or none for the following " + + s"options: '$JDBC_PARTITION_COLUMN', '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', " + + s"and '$JDBC_NUM_PARTITIONS'") val fetchSize = { val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt require(size >= 0, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 2bdc43254133e..7097069b92b78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -286,7 +286,7 @@ private[jdbc] class JDBCRDD( conn = getConnection() val dialect = JdbcDialects.get(url) import scala.collection.JavaConverters._ - dialect.beforeFetch(conn, options.asConnectionProperties.asScala.toMap) + dialect.beforeFetch(conn, options.asProperties.asScala.toMap) // H2's JDBC driver does not support the setSchema() method. We pass a // fully-qualified table name in the SELECT statement. I don't know how to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 8b45dba04d29e..272cb4a82641e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -64,7 +64,8 @@ private[sql] object JDBCRelation extends Logging { s"bound. Lower bound: $lowerBound; Upper bound: $upperBound") val numPartitions = - if ((upperBound - lowerBound) >= partitioning.numPartitions) { + if ((upperBound - lowerBound) >= partitioning.numPartitions || /* check for overflow */ + (upperBound - lowerBound) < 0) { partitioning.numPartitions } else { logWarning("The number of partitions is reduced because the specified number of " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 74dcfb06f5c2b..37e7bb0a59bb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -29,6 +29,8 @@ class JdbcRelationProvider extends CreatableRelationProvider override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { + import JDBCOptions._ + val jdbcOptions = new JDBCOptions(parameters) val partitionColumn = jdbcOptions.partitionColumn val lowerBound = jdbcOptions.lowerBound @@ -36,10 +38,13 @@ class JdbcRelationProvider extends CreatableRelationProvider val numPartitions = jdbcOptions.numPartitions val partitionInfo = if (partitionColumn.isEmpty) { - assert(lowerBound.isEmpty && upperBound.isEmpty) + assert(lowerBound.isEmpty && upperBound.isEmpty, "When 'partitionColumn' is not specified, " + + s"'$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty") null } else { - assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty) + assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty, + s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " + + s"'$JDBC_NUM_PARTITIONS' are also required") JDBCPartitioningInfo( partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 5fc3c2753b6cf..0183805d56257 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -652,8 +652,17 @@ object JdbcUtils extends Logging { case e: SQLException => val cause = e.getNextException if (cause != null && e.getCause != cause) { + // If there is no cause already, set 'next exception' as cause. If cause is null, + // it *may* be because no cause was set yet if (e.getCause == null) { - e.initCause(cause) + try { + e.initCause(cause) + } catch { + // Or it may be null because the cause *was* explicitly initialized, to *null*, + // in which case this fails. There is no other way to detect it. + // addSuppressed in this case as well. + case _: IllegalStateException => e.addSuppressed(cause) + } } else { e.addSuppressed(cause) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 4f2963da9ace9..5a92a71d19e78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -86,8 +86,8 @@ abstract class JsonDataSource extends Serializable { object JsonDataSource { def apply(options: JSONOptions): JsonDataSource = { - if (options.wholeFile) { - WholeFileJsonDataSource + if (options.multiLine) { + MultiLineJsonDataSource } else { TextInputJsonDataSource } @@ -147,7 +147,7 @@ object TextInputJsonDataSource extends JsonDataSource { } } -object WholeFileJsonDataSource extends JsonDataSource { +object MultiLineJsonDataSource extends JsonDataSource { override val isSplitable: Boolean = { false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 2f3a2c62b912c..87fbf8b1bc9c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -50,7 +50,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.{SerializableConfiguration, ThreadUtils} class ParquetFileFormat extends FileFormat @@ -479,24 +479,29 @@ object ParquetFileFormat extends Logging { partFiles: Seq[FileStatus], ignoreCorruptFiles: Boolean): Seq[Footer] = { val parFiles = partFiles.par - parFiles.tasksupport = new ForkJoinTaskSupport(new ForkJoinPool(8)) - parFiles.flatMap { currentFile => - try { - // Skips row group information since we only need the schema. - // ParquetFileReader.readFooter throws RuntimeException, instead of IOException, - // when it can't read the footer. - Some(new Footer(currentFile.getPath(), - ParquetFileReader.readFooter( - conf, currentFile, SKIP_ROW_GROUPS))) - } catch { case e: RuntimeException => - if (ignoreCorruptFiles) { - logWarning(s"Skipped the footer in the corrupted file: $currentFile", e) - None - } else { - throw new IOException(s"Could not read footer for file: $currentFile", e) + val pool = ThreadUtils.newForkJoinPool("readingParquetFooters", 8) + parFiles.tasksupport = new ForkJoinTaskSupport(pool) + try { + parFiles.flatMap { currentFile => + try { + // Skips row group information since we only need the schema. + // ParquetFileReader.readFooter throws RuntimeException, instead of IOException, + // when it can't read the footer. + Some(new Footer(currentFile.getPath(), + ParquetFileReader.readFooter( + conf, currentFile, SKIP_ROW_GROUPS))) + } catch { case e: RuntimeException => + if (ignoreCorruptFiles) { + logWarning(s"Skipped the footer in the corrupted file: $currentFile", e) + None + } else { + throw new IOException(s"Could not read footer for file: $currentFile", e) + } } - } - }.seq + }.seq + } finally { + pool.shutdown() + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index a6a6cef5861f3..763841efbd9f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -166,7 +166,14 @@ private[parquet] object ParquetFilters { * Converts data sources filters to Parquet filter predicates. */ def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { - val dataTypeOf = getFieldMap(schema) + val nameToType = getFieldMap(schema) + + // Parquet does not allow dots in the column name because dots are used as a column path + // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates + // with missing columns. The incorrect results could be got from Parquet when we push down + // filters for the column having dots in the names. Thus, we do not push down such filters. + // See SPARK-20364. + def canMakeFilterOn(name: String): Boolean = nameToType.contains(name) && !name.contains(".") // NOTE: // @@ -184,30 +191,30 @@ private[parquet] object ParquetFilters { // Probably I missed something and obviously this should be changed. predicate match { - case sources.IsNull(name) if dataTypeOf.contains(name) => - makeEq.lift(dataTypeOf(name)).map(_(name, null)) - case sources.IsNotNull(name) if dataTypeOf.contains(name) => - makeNotEq.lift(dataTypeOf(name)).map(_(name, null)) - - case sources.EqualTo(name, value) if dataTypeOf.contains(name) => - makeEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.Not(sources.EqualTo(name, value)) if dataTypeOf.contains(name) => - makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) - - case sources.EqualNullSafe(name, value) if dataTypeOf.contains(name) => - makeEq.lift(dataTypeOf(name)).map(_(name, value)) - case sources.Not(sources.EqualNullSafe(name, value)) if dataTypeOf.contains(name) => - makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) - - case sources.LessThan(name, value) if dataTypeOf.contains(name) => - makeLt.lift(dataTypeOf(name)).map(_(name, value)) - case sources.LessThanOrEqual(name, value) if dataTypeOf.contains(name) => - makeLtEq.lift(dataTypeOf(name)).map(_(name, value)) - - case sources.GreaterThan(name, value) if dataTypeOf.contains(name) => - makeGt.lift(dataTypeOf(name)).map(_(name, value)) - case sources.GreaterThanOrEqual(name, value) if dataTypeOf.contains(name) => - makeGtEq.lift(dataTypeOf(name)).map(_(name, value)) + case sources.IsNull(name) if canMakeFilterOn(name) => + makeEq.lift(nameToType(name)).map(_(name, null)) + case sources.IsNotNull(name) if canMakeFilterOn(name) => + makeNotEq.lift(nameToType(name)).map(_(name, null)) + + case sources.EqualTo(name, value) if canMakeFilterOn(name) => + makeEq.lift(nameToType(name)).map(_(name, value)) + case sources.Not(sources.EqualTo(name, value)) if canMakeFilterOn(name) => + makeNotEq.lift(nameToType(name)).map(_(name, value)) + + case sources.EqualNullSafe(name, value) if canMakeFilterOn(name) => + makeEq.lift(nameToType(name)).map(_(name, value)) + case sources.Not(sources.EqualNullSafe(name, value)) if canMakeFilterOn(name) => + makeNotEq.lift(nameToType(name)).map(_(name, value)) + + case sources.LessThan(name, value) if canMakeFilterOn(name) => + makeLt.lift(nameToType(name)).map(_(name, value)) + case sources.LessThanOrEqual(name, value) if canMakeFilterOn(name) => + makeLtEq.lift(nameToType(name)).map(_(name, value)) + + case sources.GreaterThan(name, value) if canMakeFilterOn(name) => + makeGt.lift(nameToType(name)).map(_(name, value)) + case sources.GreaterThanOrEqual(name, value) if canMakeFilterOn(name) => + makeGtEq.lift(nameToType(name)).map(_(name, value)) case sources.And(lhs, rhs) => // At here, it is not safe to just convert one side if we do not understand the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index 38b0e33937f3c..9fcca1610ff05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.types._ * of this option is propagated to this class by the `init()` method and its Hadoop configuration * argument. */ -private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging { +class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging { // A `ValueWriter` is responsible for writing a field of an `InternalRow` to the record consumer. // Here we are using `SpecializedGetters` rather than `InternalRow` so that we can directly access // data in `ArrayData` without the help of `SpecificMutableRow`. @@ -58,7 +58,7 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit private var schema: StructType = _ // `ValueWriter`s for all fields of the schema - private var rootFieldWriters: Seq[ValueWriter] = _ + private var rootFieldWriters: Array[ValueWriter] = _ // The Parquet `RecordConsumer` to which all `InternalRow`s are written private var recordConsumer: RecordConsumer = _ @@ -90,7 +90,7 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit } - this.rootFieldWriters = schema.map(_.dataType).map(makeWriter) + this.rootFieldWriters = schema.map(_.dataType).map(makeWriter).toArray[ValueWriter] val messageType = new ParquetSchemaConverter(configuration).convert(schema) val metadata = Map(ParquetReadSupport.SPARK_METADATA_KEY -> schemaString).asJava @@ -116,7 +116,7 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit } private def writeFields( - row: InternalRow, schema: StructType, fieldWriters: Seq[ValueWriter]): Unit = { + row: InternalRow, schema: StructType, fieldWriters: Array[ValueWriter]): Unit = { var i = 0 while (i < row.numFields) { if (!row.isNullAt(i)) { @@ -192,7 +192,7 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit makeDecimalWriter(precision, scale) case t: StructType => - val fieldWriters = t.map(_.dataType).map(makeWriter) + val fieldWriters = t.map(_.dataType).map(makeWriter).toArray[ValueWriter] (row: SpecializedGetters, ordinal: Int) => consumeGroup { writeFields(row.getStruct(ordinal, t.length), t, fieldWriters) @@ -439,7 +439,7 @@ private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] wit } } -private[parquet] object ParquetWriteSupport { +object ParquetWriteSupport { val SPARK_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.attributes" def setSchema(schema: StructType, configuration: Configuration): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 7abf2ae5166b5..9647f2c0edccb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.DDLUtils @@ -127,11 +127,11 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi val resolver = sparkSession.sessionState.conf.resolver val tableCols = existingTable.schema.map(_.name) - // As we are inserting into an existing table, we should respect the existing schema and - // adjust the column order of the given dataframe according to it, or throw exception - // if the column names do not match. + // As we are inserting into an existing table, we should respect the existing schema, preserve + // the case and adjust the column order of the given DataFrame according to it, or throw + // an exception if the column names do not match. val adjustedColumns = tableCols.map { col => - query.resolve(Seq(col), resolver).getOrElse { + query.resolve(Seq(col), resolver).map(Alias(_, col)()).getOrElse { val inputColumns = query.schema.map(_.name).mkString(", ") throw new AnalysisException( s"cannot resolve '$col' given input columns: [$inputColumns]") @@ -168,15 +168,9 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi """.stripMargin) } - val newQuery = if (adjustedColumns != query.output) { - Project(adjustedColumns, query) - } else { - query - } - c.copy( tableDesc = existingTable, - query = Some(newQuery)) + query = Some(Project(adjustedColumns, query))) // Here we normalize partition, bucket and sort column names, w.r.t. the case sensitivity // config, and do various checks: @@ -315,7 +309,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi * table. It also does data type casting and field renaming, to make sure that the columns to be * inserted have the correct data type and fields have the correct names. */ -case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { +case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { private def preprocess( insert: InsertIntoTable, tblName: String, @@ -367,7 +361,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { // Renaming is needed for handling the following cases like // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 // 2) Target tables have column metadata - Alias(Cast(actual, expected.dataType), expected.name)( + Alias(cast(actual, expected.dataType), expected.name)( explicitMetadata = Option(expected.metadata)) } } @@ -382,7 +376,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case i @ InsertIntoTable(table, _, query, _, _) if table.resolved && query.resolved => table match { - case relation: CatalogRelation => + case relation: HiveTableRelation => val metadata = relation.tableMeta preprocess(i, metadata.identifier.quotedString, metadata.partitionColumnNames) case LogicalRelation(h: HadoopFsRelation, _, catalogTable) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index b91d077442557..91d59473e5ed6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.internal.SQLConf /** @@ -37,6 +38,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { private def adaptiveExecutionEnabled: Boolean = conf.adaptiveExecutionEnabled + private def adaptiveExecutionDisabledForJoin: Boolean = conf.adaptiveExecutionDisabledForJoining + private def minNumPostShufflePartitions: Option[Int] = { val minNumPostShufflePartitions = conf.minNumPostShufflePartitions if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None @@ -62,7 +65,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { */ private def withExchangeCoordinator( children: Seq[SparkPlan], - requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = { + requiredChildDistributions: Seq[Distribution], + disableAdaptiveExecution: Boolean = false): Seq[SparkPlan] = { val supportsCoordinator = if (children.exists(_.isInstanceOf[ShuffleExchange])) { // Right now, ExchangeCoordinator only support HashPartitionings. @@ -87,7 +91,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } val withCoordinator = - if (adaptiveExecutionEnabled && supportsCoordinator) { + if (adaptiveExecutionEnabled && !disableAdaptiveExecution && supportsCoordinator) { val coordinator = new ExchangeCoordinator( children.length, @@ -230,7 +234,19 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // at here for now. // Once we finish https://issues.apache.org/jira/browse/SPARK-10665, // we can first add Exchanges and then add coordinator once we have a DAG of query fragments. - children = withExchangeCoordinator(children, requiredChildDistributions) + + // We have observed some performance issue when enabling adaptive execution in performing + // joining. It is hard to predict how many results the joining will produce. When the estimation + // from previous stage is within the postShuffleSize estimation, it will produce only one + // partition and makes the performance bad. We will disable adaptive execution for + // joining now till we figure out a better way of size estimation in joining. + val disableAdaptiveExecutionForJoin = + operator.isInstanceOf[SortMergeJoinExec] && + adaptiveExecutionEnabled && + adaptiveExecutionDisabledForJoin + + children = + withExchangeCoordinator(children, requiredChildDistributions, disableAdaptiveExecutionForJoin) // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings: children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index d993ea6c6cef9..4b52f3e4c49b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -23,7 +23,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.internal.SQLConf @@ -58,6 +59,24 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { child.executeBroadcast() } + + // `ReusedExchangeExec` can have distinct set of output attribute ids from its child, we need + // to update the attribute ids in `outputPartitioning` and `outputOrdering`. + private lazy val updateAttr: Expression => Expression = { + val originalAttrToNewAttr = AttributeMap(child.output.zip(output)) + e => e.transform { + case attr: Attribute => originalAttrToNewAttr.getOrElse(attr, attr) + } + } + + override def outputPartitioning: Partitioning = child.outputPartitioning match { + case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr)) + case other => other + } + + override def outputOrdering: Seq[SortOrder] = { + child.outputOrdering.map(updateAttr(_).asInstanceOf[SortOrder]) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index f06544ea8ed04..eebe6ad2e7944 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -40,6 +40,9 @@ case class ShuffleExchange( child: SparkPlan, @transient coordinator: Option[ExchangeCoordinator]) extends Exchange { + // NOTE: coordinator can be null after serialization/deserialization, + // e.g. it can be null on the Executor side + override lazy val metrics = Map( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size")) @@ -47,7 +50,7 @@ case class ShuffleExchange( val extraInfo = coordinator match { case Some(exchangeCoordinator) => s"(coordinator id: ${System.identityHashCode(exchangeCoordinator)})" - case None => "" + case _ => "" } val simpleNodeName = "Exchange" @@ -70,7 +73,7 @@ case class ShuffleExchange( // the plan. coordinator match { case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this) - case None => + case _ => } } @@ -117,7 +120,7 @@ case class ShuffleExchange( val shuffleRDD = exchangeCoordinator.postShuffleRDD(this) assert(shuffleRDD.partitions.length == newPartitioning.numPartitions) shuffleRDD - case None => + case _ => val shuffleDependency = prepareShuffleDependency() preparePostShuffleRDD(shuffleDependency) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 0bc261d593df4..69715ab1f675f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -257,8 +257,8 @@ case class BroadcastHashJoinExec( s""" |boolean $conditionPassed = true; |${eval.trim} - |${ev.code} |if ($matched != null) { + | ${ev.code} | $conditionPassed = !${ev.isNull} && ${ev.value}; |} """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index f380986951317..4d261dd422bc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -35,11 +35,12 @@ class UnsafeCartesianRDD( left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int, + inMemoryBufferThreshold: Int, spillThreshold: Int) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { - val rowArray = new ExternalAppendOnlyUnsafeRowArray(spillThreshold) + val rowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold) val partition = split.asInstanceOf[CartesianPartition] rdd2.iterator(partition.s2, context).foreach(rowArray.add) @@ -71,9 +72,12 @@ case class CartesianProductExec( val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]] val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]] - val spillThreshold = sqlContext.conf.cartesianProductExecBufferSpillThreshold - - val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size, spillThreshold) + val pair = new UnsafeCartesianRDD( + leftResults, + rightResults, + right.output.size, + sqlContext.conf.cartesianProductExecBufferInMemoryThreshold, + sqlContext.conf.cartesianProductExecBufferSpillThreshold) pair.mapPartitionsWithIndexInternal { (index, iter) => val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) val filtered = if (condition.isDefined) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index c6aae1a4db2e4..70dada8b63ae9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -82,7 +82,7 @@ case class SortMergeJoinExec( override def outputOrdering: Seq[SortOrder] = joinType match { // For inner join, orders of both sides keys should be kept. - case Inner => + case _: InnerLike => val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering) val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering) leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) => @@ -130,9 +130,14 @@ case class SortMergeJoinExec( sqlContext.conf.sortMergeJoinExecBufferSpillThreshold } + private def getInMemoryThreshold: Int = { + sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val spillThreshold = getSpillThreshold + val inMemoryThreshold = getInMemoryThreshold left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => Boolean = { condition.map { cond => @@ -158,6 +163,7 @@ case class SortMergeJoinExec( keyOrdering, RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter), + inMemoryThreshold, spillThreshold ) private[this] val joinRow = new JoinedRow @@ -201,6 +207,7 @@ case class SortMergeJoinExec( keyOrdering, streamedIter = RowIterator.fromScala(leftIter), bufferedIter = RowIterator.fromScala(rightIter), + inMemoryThreshold, spillThreshold ) val rightNullRow = new GenericInternalRow(right.output.length) @@ -214,6 +221,7 @@ case class SortMergeJoinExec( keyOrdering, streamedIter = RowIterator.fromScala(rightIter), bufferedIter = RowIterator.fromScala(leftIter), + inMemoryThreshold, spillThreshold ) val leftNullRow = new GenericInternalRow(left.output.length) @@ -247,6 +255,7 @@ case class SortMergeJoinExec( keyOrdering, RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter), + inMemoryThreshold, spillThreshold ) private[this] val joinRow = new JoinedRow @@ -281,6 +290,7 @@ case class SortMergeJoinExec( keyOrdering, RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter), + inMemoryThreshold, spillThreshold ) private[this] val joinRow = new JoinedRow @@ -290,6 +300,7 @@ case class SortMergeJoinExec( currentLeftRow = smjScanner.getStreamedRow val currentRightMatches = smjScanner.getBufferedMatches if (currentRightMatches == null || currentRightMatches.length == 0) { + numOutputRows += 1 return true } var found = false @@ -321,6 +332,7 @@ case class SortMergeJoinExec( keyOrdering, RowIterator.fromScala(leftIter), RowIterator.fromScala(rightIter), + inMemoryThreshold, spillThreshold ) private[this] val joinRow = new JoinedRow @@ -371,6 +383,7 @@ case class SortMergeJoinExec( keys: Seq[Expression], input: Seq[Attribute]): Seq[ExprCode] = { ctx.INPUT_ROW = row + ctx.currentVars = null keys.map(BindReferences.bindReference(_, input).genCode(ctx)) } @@ -418,8 +431,10 @@ case class SortMergeJoinExec( val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName val spillThreshold = getSpillThreshold + val inMemoryThreshold = getInMemoryThreshold - ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);") + ctx.addMutableState(clsName, matches, + s"$matches = new $clsName($inMemoryThreshold, $spillThreshold);") // Copy the left keys as class members so they could be used in next function call. val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -624,6 +639,9 @@ case class SortMergeJoinExec( * @param streamedIter an input whose rows will be streamed. * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that * have the same join key. + * @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by + * internal buffer + * @param spillThreshold Threshold for number of rows to be spilled by internal buffer */ private[joins] class SortMergeJoinScanner( streamedKeyGenerator: Projection, @@ -631,7 +649,8 @@ private[joins] class SortMergeJoinScanner( keyOrdering: Ordering[InternalRow], streamedIter: RowIterator, bufferedIter: RowIterator, - bufferThreshold: Int) { + inMemoryThreshold: Int, + spillThreshold: Int) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -642,7 +661,8 @@ private[joins] class SortMergeJoinScanner( */ private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ - private[this] val bufferedMatches = new ExternalAppendOnlyUnsafeRowArray(bufferThreshold) + private[this] val bufferedMatches = + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) // Initialization (note: do _not_ want to advance streamed here). advancedBufferedToRowWithNullFreeJoinKey() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 48c7b80bffe03..34391818f3b9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.streaming.GroupStateTimeout import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -361,8 +362,11 @@ object MapGroupsExec { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, + timeoutConf: GroupStateTimeout, child: SparkPlan): MapGroupsExec = { - val f = (key: Any, values: Iterator[Any]) => func(key, values, new GroupStateImpl[Any](None)) + val f = (key: Any, values: Iterator[Any]) => { + func(key, values, GroupStateImpl.createForBatch(timeoutConf)) + } new MapGroupsExec(f, keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala index 408c8f81f17ba..77bc0ba5548dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala @@ -169,13 +169,15 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( */ private def compact(batchId: Long, logs: Array[T]): Boolean = { val validBatches = getValidBatchesBeforeCompactionBatch(batchId, compactInterval) - val allLogs = validBatches.flatMap(batchId => super.get(batchId)).flatten ++ logs - if (super.add(batchId, compactLogs(allLogs).toArray)) { - true - } else { - // Return false as there is another writer. - false - } + val allLogs = validBatches.map { id => + super.get(id).getOrElse { + throw new IllegalStateException( + s"${batchIdToPath(id)} doesn't exist when compacting batch $batchId " + + s"(compactInterval: $compactInterval)") + } + }.flatten ++ logs + // Return false as there is another writer. + super.add(batchId, compactLogs(allLogs).toArray) } /** @@ -190,7 +192,13 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( if (latestId >= 0) { try { val logs = - getAllValidBatches(latestId, compactInterval).flatMap(id => super.get(id)).flatten + getAllValidBatches(latestId, compactInterval).map { id => + super.get(id).getOrElse { + throw new IllegalStateException( + s"${batchIdToPath(id)} doesn't exist " + + s"(latestId: $latestId, compactInterval: $compactInterval)") + } + }.flatten return compactLogs(logs).toArray } catch { case e: IOException => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala index 25cf609fc336e..55e7508b2ed29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala @@ -27,27 +27,25 @@ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.AccumulatorV2 /** Class for collecting event time stats with an accumulator */ -case class EventTimeStats(var max: Long, var min: Long, var sum: Long, var count: Long) { +case class EventTimeStats(var max: Long, var min: Long, var avg: Double, var count: Long) { def add(eventTime: Long): Unit = { this.max = math.max(this.max, eventTime) this.min = math.min(this.min, eventTime) - this.sum += eventTime this.count += 1 + this.avg += (eventTime - avg) / count } def merge(that: EventTimeStats): Unit = { this.max = math.max(this.max, that.max) this.min = math.min(this.min, that.min) - this.sum += that.sum this.count += that.count + this.avg += (that.avg - this.avg) * that.count / this.count } - - def avg: Long = sum / count } object EventTimeStats { def zero: EventTimeStats = EventTimeStats( - max = Long.MinValue, min = Long.MaxValue, sum = 0L, count = 0L) + max = Long.MinValue, min = Long.MaxValue, avg = 0.0, count = 0L) } /** Accumulator that collects stats on event time in a batch. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 07ec4e9429e42..2a652920c10c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -53,6 +53,26 @@ object FileStreamSink extends Logging { case _ => false } } + + /** + * Returns true if the path is the metadata dir or its ancestor is the metadata dir. + * E.g.: + * - ancestorIsMetadataDirectory(/.../_spark_metadata) => true + * - ancestorIsMetadataDirectory(/.../_spark_metadata/0) => true + * - ancestorIsMetadataDirectory(/a/b/c) => false + */ + def ancestorIsMetadataDirectory(path: Path, hadoopConf: Configuration): Boolean = { + val fs = path.getFileSystem(hadoopConf) + var currentPath = path.makeQualified(fs.getUri, fs.getWorkingDirectory) + while (currentPath != null) { + if (currentPath.getName == FileStreamSink.metadataDir) { + return true + } else { + currentPath = currentPath.getParent + } + } + return false + } } /** @@ -91,15 +111,6 @@ class FileStreamSink( case _ => // Do nothing } - // Get the actual partition columns as attributes after matching them by name with - // the given columns names. - val partitionColumns: Seq[Attribute] = partitionColumnNames.map { col => - val nameEquality = data.sparkSession.sessionState.conf.resolver - data.logicalPlan.output.find(f => nameEquality(f.name, col)).getOrElse { - throw new RuntimeException(s"Partition column $col not found in schema ${data.schema}") - } - } - FileFormatWriter.write( sparkSession = sparkSession, queryExecution = data.queryExecution, @@ -107,7 +118,7 @@ class FileStreamSink( committer = committer, outputSpec = FileFormatWriter.OutputSpec(path, Map.empty), hadoopConf = hadoopConf, - partitionColumns = partitionColumns, + partitionColumnNames = partitionColumnNames, bucketSpec = None, refreshFunction = _ => (), options = options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala index 8d718b2164d22..c9939ac1db746 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.net.URI + import org.apache.hadoop.fs.{FileStatus, Path} import org.json4s.NoTypeHints import org.json4s.jackson.Serialization @@ -47,7 +49,8 @@ case class SinkFileStatus( action: String) { def toFileStatus: FileStatus = { - new FileStatus(size, isDir, blockReplication, blockSize, modificationTime, new Path(path)) + new FileStatus( + size, isDir, blockReplication, blockSize, modificationTime, new Path(new URI(path))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala index 33e6a1d5d6e18..8628471fdb925 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala @@ -115,7 +115,10 @@ class FileStreamSourceLog( Map.empty[Long, Option[Array[FileEntry]]] } - (existedBatches ++ retrievedBatches).map(i => i._1 -> i._2.get).toArray.sortBy(_._1) + val batches = + (existedBatches ++ retrievedBatches).map(i => i._1 -> i._2.get).toArray.sortBy(_._1) + HDFSMetadataLog.verifyBatchIds(batches.map(_._1), startId, endId) + batches } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index e42df5dd61c70..3ceb4cf84a413 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -120,7 +120,7 @@ case class FlatMapGroupsWithStateExec( val filteredIter = watermarkPredicateForData match { case Some(predicate) if timeoutConf == EventTimeTimeout => iter.filter(row => !predicate.eval(row)) - case None => + case _ => iter } @@ -215,7 +215,7 @@ case class FlatMapGroupsWithStateExec( val keyObj = getKeyObj(keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects val stateObjOption = getStateObj(prevStateRowOption) - val keyedState = new GroupStateImpl( + val keyedState = GroupStateImpl.createForStreaming( stateObjOption, batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), @@ -230,6 +230,20 @@ case class FlatMapGroupsWithStateExec( // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { + + val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp + // If the state has not yet been set but timeout has been set, then + // we have to generate a row to save the timeout. However, attempting serialize + // null using case class encoder throws - + // java.lang.NullPointerException: Null value appeared in non-nullable field: + // If the schema is inferred from a Scala tuple / case class, or a Java bean, please + // try to use scala.Option[_] or other nullable types. + if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) { + throw new IllegalStateException( + "Cannot set timeout when state is not defined, that is, state has not been" + + "initialized or has been removed") + } + if (keyedState.hasRemoved) { store.remove(keyRow) numUpdatedStateRows += 1 @@ -239,7 +253,6 @@ case class FlatMapGroupsWithStateExec( case Some(row) => getTimeoutTimestamp(row) case None => NO_TIMESTAMP } - val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp val stateRowToWrite = if (keyedState.hasUpdated) { getStateRow(keyedState.get) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index 148d92247d6f0..4401e86936af9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -38,20 +38,13 @@ import org.apache.spark.unsafe.types.CalendarInterval * @param hasTimedOut Whether the key for which this state wrapped is being created is * getting timed out or not. */ -private[sql] class GroupStateImpl[S]( +private[sql] class GroupStateImpl[S] private( optionalValue: Option[S], batchProcessingTimeMs: Long, eventTimeWatermarkMs: Long, timeoutConf: GroupStateTimeout, override val hasTimedOut: Boolean) extends GroupState[S] { - // Constructor to create dummy state when using mapGroupsWithState in a batch query - def this(optionalValue: Option[S]) = this( - optionalValue, - batchProcessingTimeMs = NO_TIMESTAMP, - eventTimeWatermarkMs = NO_TIMESTAMP, - timeoutConf = GroupStateTimeout.NoTimeout, - hasTimedOut = false) private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined private var updated: Boolean = false // whether value has been updated (but not removed) @@ -91,7 +84,6 @@ private[sql] class GroupStateImpl[S]( defined = false updated = false removed = true - timeoutTimestamp = NO_TIMESTAMP } override def setTimeoutDuration(durationMs: Long): Unit = { @@ -100,21 +92,10 @@ private[sql] class GroupStateImpl[S]( "Cannot set timeout duration without enabling processing time timeout in " + "map/flatMapGroupsWithState") } - if (!defined) { - throw new IllegalStateException( - "Cannot set timeout information without any state value, " + - "state has either not been initialized, or has already been removed") - } - if (durationMs <= 0) { throw new IllegalArgumentException("Timeout duration must be positive") } - if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) { - timeoutTimestamp = durationMs + batchProcessingTimeMs - } else { - // This is being called in a batch query, hence no processing timestamp. - // Just ignore any attempts to set timeout. - } + timeoutTimestamp = durationMs + batchProcessingTimeMs } override def setTimeoutDuration(duration: String): Unit = { @@ -135,12 +116,7 @@ private[sql] class GroupStateImpl[S]( s"Timeout timestamp ($timestampMs) cannot be earlier than the " + s"current watermark ($eventTimeWatermarkMs)") } - if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) { - timeoutTimestamp = timestampMs - } else { - // This is being called in a batch query, hence no processing timestamp. - // Just ignore any attempts to set timeout. - } + timeoutTimestamp = timestampMs } @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") @@ -213,11 +189,6 @@ private[sql] class GroupStateImpl[S]( "Cannot set timeout timestamp without enabling event time timeout in " + "map/flatMapGroupsWithState") } - if (!defined) { - throw new IllegalStateException( - "Cannot set timeout timestamp without any state value, " + - "state has either not been initialized, or has already been removed") - } } } @@ -225,4 +196,23 @@ private[sql] class GroupStateImpl[S]( private[sql] object GroupStateImpl { // Value used represent the lack of valid timestamp as a long val NO_TIMESTAMP = -1L + + def createForStreaming[S]( + optionalValue: Option[S], + batchProcessingTimeMs: Long, + eventTimeWatermarkMs: Long, + timeoutConf: GroupStateTimeout, + hasTimedOut: Boolean): GroupStateImpl[S] = { + new GroupStateImpl[S]( + optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs, timeoutConf, hasTimedOut) + } + + def createForBatch(timeoutConf: GroupStateTimeout): GroupStateImpl[Any] = { + new GroupStateImpl[Any]( + optionalValue = None, + batchProcessingTimeMs = NO_TIMESTAMP, + eventTimeWatermarkMs = NO_TIMESTAMP, + timeoutConf, + hasTimedOut = false) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 46bfc297931fb..5f8973fd09460 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -123,7 +123,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: serialize(metadata, output) return Some(tempPath) } finally { - IOUtils.closeQuietly(output) + output.close() } } catch { case e: FileAlreadyExistsException => @@ -211,13 +211,17 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: } override def get(startId: Option[Long], endId: Option[Long]): Array[(Long, T)] = { + assert(startId.isEmpty || endId.isEmpty || startId.get <= endId.get) val files = fileManager.list(metadataPath, batchFilesFilter) val batchIds = files .map(f => pathToBatchId(f.getPath)) .filter { batchId => (endId.isEmpty || batchId <= endId.get) && (startId.isEmpty || batchId >= startId.get) - } - batchIds.sorted.map(batchId => (batchId, get(batchId))).filter(_._2.isDefined).map { + }.sorted + + verifyBatchIds(batchIds, startId, endId) + + batchIds.map(batchId => (batchId, get(batchId))).filter(_._2.isDefined).map { case (batchId, metadataOption) => (batchId, metadataOption.get) } @@ -437,4 +441,51 @@ object HDFSMetadataLog { } } } + + /** + * Verify if batchIds are continuous and between `startId` and `endId`. + * + * @param batchIds the sorted ids to verify. + * @param startId the start id. If it's set, batchIds should start with this id. + * @param endId the start id. If it's set, batchIds should end with this id. + */ + def verifyBatchIds(batchIds: Seq[Long], startId: Option[Long], endId: Option[Long]): Unit = { + // Verify that we can get all batches between `startId` and `endId`. + if (startId.isDefined || endId.isDefined) { + if (batchIds.isEmpty) { + throw new IllegalStateException(s"batch ${startId.orElse(endId).get} doesn't exist") + } + if (startId.isDefined) { + val minBatchId = batchIds.head + assert(minBatchId >= startId.get) + if (minBatchId != startId.get) { + val missingBatchIds = startId.get to minBatchId + throw new IllegalStateException( + s"batches (${missingBatchIds.mkString(", ")}) don't exist " + + s"(startId: $startId, endId: $endId)") + } + } + + if (endId.isDefined) { + val maxBatchId = batchIds.last + assert(maxBatchId <= endId.get) + if (maxBatchId != endId.get) { + val missingBatchIds = maxBatchId to endId.get + throw new IllegalStateException( + s"batches (${missingBatchIds.mkString(", ")}) don't exist " + + s"(startId: $startId, endId: $endId)") + } + } + } + + if (batchIds.nonEmpty) { + val minBatchId = batchIds.head + val maxBatchId = batchIds.last + val missingBatchIds = (minBatchId to maxBatchId).toSet -- batchIds + if (missingBatchIds.nonEmpty) { + throw new IllegalStateException(s"batches (${missingBatchIds.mkString(", ")}) " + + s"don't exist (startId: $startId, endId: $endId)") + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala index 5551d12fa8ad2..b84e6ce64c611 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala @@ -40,7 +40,7 @@ class MetricsReporter( // Metric names should not have . in them, so that all the metrics of a query are identified // together in Ganglia as a single metric group registerGauge("inputRate-total", () => stream.lastProgress.inputRowsPerSecond) - registerGauge("processingRate-total", () => stream.lastProgress.inputRowsPerSecond) + registerGauge("processingRate-total", () => stream.lastProgress.processedRowsPerSecond) registerGauge("latency", () => stream.lastProgress.durationMs.get("triggerExecution").longValue()) private def registerGauge[T](name: String, f: () => T)(implicit num: Numeric[T]): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 693933f95a231..db46fcd9dfe78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.streaming import java.text.SimpleDateFormat -import java.util.{Date, TimeZone, UUID} +import java.util.{Date, UUID} import scala.collection.mutable import scala.collection.JavaConverters._ @@ -26,6 +26,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent @@ -82,7 +83,7 @@ trait ProgressReporter extends Logging { private var lastNoDataProgressEventTime = Long.MinValue private val timestampFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601 - timestampFormat.setTimeZone(TimeZone.getTimeZone("UTC")) + timestampFormat.setTimeZone(DateTimeUtils.getTimeZone("UTC")) @volatile protected var currentStatus: StreamingQueryStatus = { @@ -266,7 +267,7 @@ trait ProgressReporter extends Logging { Map( "max" -> stats.max, "min" -> stats.min, - "avg" -> stats.avg).mapValues(formatTimestamp) + "avg" -> stats.avg.toLong).mapValues(formatTimestamp) }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp ExecutionStats(numInputRows, stateOperators, eventTimeStats) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala new file mode 100644 index 0000000000000..e61a8eb628891 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.concurrent.TimeUnit + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.types._ +import org.apache.spark.util.{ManualClock, SystemClock} + +/** + * A source that generates increment long values with timestamps. Each generated row has two + * columns: a timestamp column for the generated time and an auto increment long column starting + * with 0L. + * + * This source supports the following options: + * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second. + * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed + * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer + * seconds. + * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the + * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may + * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. + */ +class RateSourceProvider extends StreamSourceProvider with DataSourceRegister { + + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = + (shortName(), RateSourceProvider.SCHEMA) + + override def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + val params = CaseInsensitiveMap(parameters) + + val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L) + if (rowsPerSecond <= 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " + + "must be positive") + } + + val rampUpTimeSeconds = + params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L) + if (rampUpTimeSeconds < 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " + + "must not be negative") + } + + val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse( + sqlContext.sparkContext.defaultParallelism) + if (numPartitions <= 0) { + throw new IllegalArgumentException( + s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " + + "must be positive") + } + + new RateStreamSource( + sqlContext, + metadataPath, + rowsPerSecond, + rampUpTimeSeconds, + numPartitions, + params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing + ) + } + override def shortName(): String = "rate" +} + +object RateSourceProvider { + val SCHEMA = + StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil) + + val VERSION = 1 +} + +class RateStreamSource( + sqlContext: SQLContext, + metadataPath: String, + rowsPerSecond: Long, + rampUpTimeSeconds: Long, + numPartitions: Int, + useManualClock: Boolean) extends Source with Logging { + + import RateSourceProvider._ + import RateStreamSource._ + + val clock = if (useManualClock) new ManualClock else new SystemClock + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private val startTimeMs = { + val metadataLog = + new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. + assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + val version = parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */ + @volatile private var lastTimeMs = startTimeMs + + override def schema: StructType = RateSourceProvider.SCHEMA + + override def getOffset: Option[Offset] = { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs))) + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return sqlContext.internalCreateDataFrame(sqlContext.sparkContext.emptyRDD, schema) + } + + val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + + val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v => + val relative = math.round((v - rangeStart) * relativeMsPerValue) + InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v) + } + sqlContext.internalCreateDataFrame(rdd, schema) + } + + override def stop(): Unit = {} + + override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]" +} + +object RateStreamSource { + + /** Calculate the end value we will emit at the time `seconds`. */ + def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = { + // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10 + // Then speedDeltaPerSecond = 2 + // + // seconds = 0 1 2 3 4 5 6 + // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds) + // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2 + val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1) + if (seconds <= rampUpTimeSeconds) { + // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to + // avoid overflow + if (seconds % 2 == 1) { + (seconds + 1) / 2 * speedDeltaPerSecond * seconds + } else { + seconds / 2 * speedDeltaPerSecond * (seconds + 1) + } + } else { + // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds + val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds) + rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index bcf0d970f7ec1..33f81d98ca593 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -23,6 +23,7 @@ import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicReference import java.util.concurrent.locks.ReentrantLock +import scala.collection.mutable.{Map => MutableMap} import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal @@ -148,15 +149,18 @@ class StreamExecution( "logicalPlan must be initialized in StreamExecutionThread " + s"but the current thread was ${Thread.currentThread}") var nextSourceId = 0L + val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]() val _logicalPlan = analyzedPlan.transform { - case StreamingRelation(dataSource, _, output) => - // Materialize source to avoid creating it in every batch - val metadataPath = s"$checkpointRoot/sources/$nextSourceId" - val source = dataSource.createSource(metadataPath) - nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. - StreamingExecutionRelation(source, output) + case streamingRelation@StreamingRelation(dataSource, _, output) => + toExecutionRelationMap.getOrElseUpdate(streamingRelation, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$checkpointRoot/sources/$nextSourceId" + val source = dataSource.createSource(metadataPath) + nextSourceId += 1 + // We still need to use the previous `output` instead of `source.schema` as attributes in + // "df.logicalPlan" has already used attributes of the previous `output`. + StreamingExecutionRelation(source, output) + }) } sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source } uniqueSources = sources.distinct @@ -252,6 +256,8 @@ class StreamExecution( */ private def runBatches(): Unit = { try { + sparkSession.sparkContext.setJobGroup(runId.toString, getBatchDescriptionString, + interruptOnCancel = true) if (sparkSession.sessionState.conf.streamingMetricsEnabled) { sparkSession.sparkContext.env.metricsSystem.registerSource(streamMetrics) } @@ -289,6 +295,7 @@ class StreamExecution( if (currentBatchId < 0) { // We'll do this initialization only once populateStartOffsets(sparkSessionToRunBatches) + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) logDebug(s"Stream running from $committedOffsets to $availableOffsets") } else { constructNextBatch() @@ -308,6 +315,7 @@ class StreamExecution( logDebug(s"batch ${currentBatchId} committed") // We'll increase currentBatchId after we complete processing current batch's data currentBatchId += 1 + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) } else { currentStatus = currentStatus.copy(isDataAvailable = false) updateStatusMessage("Waiting for data to arrive") @@ -421,7 +429,10 @@ class StreamExecution( availableOffsets = nextOffsets.toStreamProgress(sources) /* Initialize committed offsets to a committed batch, which at this * is the second latest batch id in the offset log. */ - offsetLog.get(latestBatchId - 1).foreach { secondLatestBatchId => + if (latestBatchId != 0) { + val secondLatestBatchId = offsetLog.get(latestBatchId - 1).getOrElse { + throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist") + } committedOffsets = secondLatestBatchId.toStreamProgress(sources) } @@ -560,10 +571,14 @@ class StreamExecution( // Now that we've updated the scheduler's persistent checkpoint, it is safe for the // sources to discard data from the previous batch. - val prevBatchOff = offsetLog.get(currentBatchId - 1) - if (prevBatchOff.isDefined) { - prevBatchOff.get.toStreamProgress(sources).foreach { - case (src, off) => src.commit(off) + if (currentBatchId != 0) { + val prevBatchOff = offsetLog.get(currentBatchId - 1) + if (prevBatchOff.isDefined) { + prevBatchOff.get.toStreamProgress(sources).foreach { + case (src, off) => src.commit(off) + } + } else { + throw new IllegalStateException(s"batch $currentBatchId doesn't exist") } } @@ -623,7 +638,8 @@ class StreamExecution( // Rewire the plan to use the new attributes that were returned by the source. val replacementMap = AttributeMap(replacements) val triggerLogicalPlan = withNewSources transformAllExpressions { - case a: Attribute if replacementMap.contains(a) => replacementMap(a) + case a: Attribute if replacementMap.contains(a) => + replacementMap(a).withMetadata(a.metadata) case ct: CurrentTimestamp => CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs, ct.dataType) @@ -684,8 +700,11 @@ class StreamExecution( // intentionally state.set(TERMINATED) if (microBatchThread.isAlive) { + sparkSession.sparkContext.cancelJobGroup(runId.toString) microBatchThread.interrupt() microBatchThread.join() + // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak + sparkSession.sparkContext.cancelJobGroup(runId.toString) } logInfo(s"Query $prettyIdString was stopped") } @@ -758,7 +777,7 @@ class StreamExecution( if (streamDeathCause != null) { throw streamDeathCause } - if (noNewData) { + if (noNewData || !isActive) { return } } @@ -825,6 +844,11 @@ class StreamExecution( } } + private def getBatchDescriptionString: String = { + val batchDescription = if (currentBatchId < 0) "init" else currentBatchId.toString + Option(name).map(_ + "
    ").getOrElse("") + + s"id = $id
    runId = $runId
    batch = $batchDescription" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 1426728f9b550..ef48fffe1d980 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{DataInputStream, DataOutputStream, FileNotFoundException, IOException} +import java.nio.channels.ClosedChannelException import java.util.Locale import scala.collection.JavaConverters._ @@ -202,13 +203,22 @@ private[state] class HDFSBackedStateStoreProvider( /** Abort all the updates made on this store. This store will not be usable any more. */ override def abort(): Unit = { verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed") + try { + state = ABORTED + if (tempDeltaFileStream != null) { + tempDeltaFileStream.close() + } + if (tempDeltaFile != null) { + fs.delete(tempDeltaFile, true) + } + } catch { + case c: ClosedChannelException => + // This can happen when underlying file output stream has been closed before the + // compression stream. + logDebug(s"Error aborting version $newVersion into $this", c) - state = ABORTED - if (tempDeltaFileStream != null) { - tempDeltaFileStream.close() - } - if (tempDeltaFile != null) { - fs.delete(tempDeltaFile, true) + case e: Exception => + logWarning(s"Error aborting version $newVersion into $this", e) } logInfo(s"Aborted version $newVersion for $this") } @@ -438,9 +448,11 @@ private[state] class HDFSBackedStateStoreProvider( private def writeSnapshotFile(version: Long, map: MapType): Unit = { val fileToWrite = snapshotFile(version) + val tempFile = + new Path(fileToWrite.getParent, s"${fileToWrite.getName}.temp-${Random.nextLong}") var output: DataOutputStream = null Utils.tryWithSafeFinally { - output = compressStream(fs.create(fileToWrite, false)) + output = compressStream(fs.create(tempFile, false)) val iter = map.entrySet().iterator() while(iter.hasNext) { val entry = iter.next() @@ -455,6 +467,12 @@ private[state] class HDFSBackedStateStoreProvider( } { if (output != null) output.close() } + if (fs.exists(fileToWrite)) { + // Skip rename if the file is alreayd created. + fs.delete(tempFile, true) + } else if (!fs.rename(tempFile, fileToWrite)) { + throw new IOException(s"Failed to rename $tempFile to $fileToWrite") + } logInfo(s"Written snapshot file for version $version of $this at $fileToWrite") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 8dbda298c87bc..d4de046787b9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -102,8 +102,13 @@ trait WatermarkSupport extends UnaryExecNode { } /** Predicate based on keys that matches data older than the watermark */ - lazy val watermarkPredicateForKeys: Option[Predicate] = - watermarkExpression.map(newPredicate(_, keyExpressions)) + lazy val watermarkPredicateForKeys: Option[Predicate] = watermarkExpression.flatMap { e => + if (keyExpressions.exists(_.metadata.contains(EventTimeWatermark.delayKey))) { + Some(newPredicate(e, keyExpressions)) + } else { + None + } + } /** Predicate based on the child output that matches data older than the watermark. */ lazy val watermarkPredicateForData: Option[Predicate] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index 23fc0bd0bce13..460fc946c3e6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -29,7 +29,8 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging private val listener = parent.listener override def render(request: HttpServletRequest): Seq[Node] = listener.synchronized { - val parameterExecutionId = request.getParameter("id") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterExecutionId = UIUtils.stripXSS(request.getParameter("id")) require(parameterExecutionId != null && parameterExecutionId.nonEmpty, "Missing execution id parameter") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala index c9f5d3b3d92d7..dfa1100c37a0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala @@ -145,10 +145,13 @@ private[window] final class AggregateProcessor( /** Update the buffer. */ def update(input: InternalRow): Unit = { - updateProjection(join(buffer, input)) + // TODO(hvanhovell) this sacrifices performance for correctness. We should make sure that + // MutableProjection makes copies of the complex input objects it buffer. + val copy = input.copy() + updateProjection(join(buffer, copy)) var i = 0 while (i < numImperatives) { - imperatives(i).update(buffer, input) + imperatives(i).update(buffer, copy) i += 1 } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 950a6794a74a3..b9c932ae21727 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -282,6 +282,7 @@ case class WindowExec( // Unwrap the expressions and factories from the map. val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + val inMemoryThreshold = sqlContext.conf.windowExecBufferInMemoryThreshold val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold // Start processing. @@ -312,7 +313,8 @@ case class WindowExec( val inputFields = child.output.length val buffer: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(spillThreshold) + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + var bufferIterator: Iterator[UnsafeRow] = _ val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index af2b4fb92062b..156002ef58fbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -195,15 +195,6 @@ private[window] final class SlidingWindowFunctionFrame( override def write(index: Int, current: InternalRow): Unit = { var bufferUpdated = index == 0 - // Add all rows to the buffer for which the input row value is equal to or less than - // the output row upper bound. - while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { - buffer.add(nextRow.copy()) - nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) - inputHighIndex += 1 - bufferUpdated = true - } - // Drop all rows from the buffer for which the input row value is smaller than // the output row lower bound. while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) { @@ -212,6 +203,19 @@ private[window] final class SlidingWindowFunctionFrame( bufferUpdated = true } + // Add all rows to the buffer for which the input row value is equal to or less than + // the output row upper bound. + while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { + if (lbound.compare(nextRow, inputLowIndex, current, index) < 0) { + inputLowIndex += 1 + } else { + buffer.add(nextRow.copy()) + bufferUpdated = true + } + nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) + inputHighIndex += 1 + } + // Only recalculate and update when the buffer changes. if (bufferUpdated) { processor.initialize(input.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f07e04368389f..35601cbdc564a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -23,13 +23,13 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try import scala.util.control.NonFatal -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint +import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.internal.SQLConf @@ -1019,7 +1019,8 @@ object functions { * @since 1.5.0 */ def broadcast[T](df: Dataset[T]): Dataset[T] = { - Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.exprEnc) + Dataset[T](df.sparkSession, + ResolvedHint(df.logicalPlan, HintInfo(isBroadcastable = Option(true))))(df.exprEnc) } /** @@ -1557,6 +1558,15 @@ object functions { */ def floor(columnName: String): Column = floor(Column(columnName)) + /** + * Computes a floating-point remainder value. The result has the same sign as the numerator. + * + * @group math_funcs + */ + def fmod(numerator: Column, denominator: Column): Column = withExpr { + Fmod(numerator.expr, denominator.expr) + } + /** * Returns the greatest value of the list of values, skipping null values. * This function takes at least 2 parameters. It will return null iff all parameters are null. @@ -2291,7 +2301,8 @@ object functions { } /** - * Left-pad the string column with + * Left-pad the string column with pad to a length of len. If the string column is longer + * than len, the return value is shortened to len characters. * * @group string_funcs * @since 1.5.0 @@ -2349,7 +2360,8 @@ object functions { def unbase64(e: Column): Column = withExpr { UnBase64(e.expr) } /** - * Right-padded with pad to a length of len. + * Right-pad the string column with pad to a length of len. If the string column is longer + * than len, the return value is shortened to len characters. * * @group string_funcs * @since 1.5.0 @@ -2793,8 +2805,6 @@ object functions { * @group datetime_funcs * @since 2.0.0 */ - @Experimental - @InterfaceStability.Evolving def window( timeColumn: Column, windowDuration: String, @@ -2847,8 +2857,6 @@ object functions { * @group datetime_funcs * @since 2.0.0 */ - @Experimental - @InterfaceStability.Evolving def window(timeColumn: Column, windowDuration: String, slideDuration: String): Column = { window(timeColumn, windowDuration, slideDuration, "0 second") } @@ -2886,8 +2894,6 @@ object functions { * @group datetime_funcs * @since 2.0.0 */ - @Experimental - @InterfaceStability.Evolving def window(timeColumn: Column, windowDuration: String): Column = { window(timeColumn, windowDuration, windowDuration, "0 second") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 2b14eca919fa4..2a801d87b12eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkConf import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration} +import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -63,6 +63,11 @@ abstract class BaseSessionStateBuilder( */ protected def newBuilder: NewBuilder + /** + * Session extensions defined in the [[SparkSession]]. + */ + protected def extensions: SparkSessionExtensions = session.extensions + /** * Extract entries from `SparkConf` and put them in the `SQLConf` */ @@ -108,7 +113,9 @@ abstract class BaseSessionStateBuilder( * * Note: this depends on the `conf` field. */ - protected lazy val sqlParser: ParserInterface = new SparkSqlParser(conf) + protected lazy val sqlParser: ParserInterface = { + extensions.buildParser(session, new SparkSqlParser(conf)) + } /** * ResourceLoader that is used to load function resources and jars. @@ -171,7 +178,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `analyzer` function. */ - protected def customResolutionRules: Seq[Rule[LogicalPlan]] = Nil + protected def customResolutionRules: Seq[Rule[LogicalPlan]] = { + extensions.buildResolutionRules(session) + } /** * Custom post resolution rules to add to the Analyzer. Prefer overriding this instead of @@ -179,7 +188,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `analyzer` function. */ - protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil + protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = { + extensions.buildPostHocResolutionRules(session) + } /** * Custom check rules to add to the Analyzer. Prefer overriding this instead of creating @@ -187,7 +198,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `analyzer` function. */ - protected def customCheckRules: Seq[LogicalPlan => Unit] = Nil + protected def customCheckRules: Seq[LogicalPlan => Unit] = { + extensions.buildCheckRules(session) + } /** * Logical query plan optimizer. @@ -207,7 +220,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `optimizer` function. */ - protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil + protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = { + extensions.buildOptimizerRules(session) + } /** * Planner that converts optimized logical plans to physical plans. @@ -227,7 +242,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `planner` function. */ - protected def customPlanningStrategies: Seq[Strategy] = Nil + protected def customPlanningStrategies: Seq[Strategy] = { + extensions.buildPlannerStrategies(session) + } /** * Create a query execution object. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index aebb663df5c92..0b8e53868c999 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.internal import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ @@ -98,14 +99,27 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { CatalogImpl.makeDataset(tables, sparkSession) } + /** + * Returns a Table for the given table/view or temporary view. + * + * Note that this function requires the table already exists in the Catalog. + * + * If the table metadata retrieval failed due to any reason (e.g., table serde class + * is not accessible or the table type is not accepted by Spark SQL), this function + * still returns the corresponding Table without the description and tableType) + */ private def makeTable(tableIdent: TableIdentifier): Table = { - val metadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent) + val metadata = try { + Some(sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent)) + } catch { + case NonFatal(_) => None + } val isTemp = sessionCatalog.isTemporaryTable(tableIdent) new Table( name = tableIdent.table, - database = metadata.identifier.database.orNull, - description = metadata.comment.orNull, - tableType = if (isTemp) "TEMPORARY" else metadata.tableType.name, + database = metadata.map(_.identifier.database).getOrElse(tableIdent.database).orNull, + description = metadata.map(_.comment.orNull).orNull, + tableType = if (isTemp) "TEMPORARY" else metadata.map(_.tableType.name).orNull, isTemporary = isTemp) } @@ -197,7 +211,11 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * `AnalysisException` when no `Table` can be found. */ override def getTable(dbName: String, tableName: String): Table = { - makeTable(TableIdentifier(tableName, Option(dbName))) + if (tableExists(dbName, tableName)) { + makeTable(TableIdentifier(tableName, Option(dbName))) + } else { + throw new AnalysisException(s"Table or view '$tableName' not found in database '$dbName'") + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 0289471bf841a..7202f1222d10f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -17,10 +17,14 @@ package org.apache.spark.sql.internal +import java.net.URL +import java.util.Locale + import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FsUrlStreamHandlerFactory import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.Logging @@ -86,35 +90,42 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * A catalog that interacts with external systems. */ - lazy val externalCatalog: ExternalCatalog = - SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( + lazy val externalCatalog: ExternalCatalog = { + val externalCatalog = SharedState.reflect[ExternalCatalog, SparkConf, Configuration]( SharedState.externalCatalogClassName(sparkContext.conf), sparkContext.conf, sparkContext.hadoopConfiguration) - // Create the default database if it doesn't exist. - { val defaultDbDefinition = CatalogDatabase( SessionCatalog.DEFAULT_DATABASE, "default database", CatalogUtils.stringToURI(warehousePath), Map()) - // Initialize default database if it doesn't exist + // Create default database if it doesn't exist if (!externalCatalog.databaseExists(SessionCatalog.DEFAULT_DATABASE)) { // There may be another Spark application creating default database at the same time, here we // set `ignoreIfExists = true` to avoid `DatabaseAlreadyExists` exception. externalCatalog.createDatabase(defaultDbDefinition, ignoreIfExists = true) } + + // Make sure we propagate external catalog events to the spark listener bus + externalCatalog.addListener(new ExternalCatalogEventListener { + override def onEvent(event: ExternalCatalogEvent): Unit = { + sparkContext.listenerBus.post(event) + } + }) + + externalCatalog } /** * A manager for global temporary views. */ - val globalTempViewManager: GlobalTempViewManager = { + lazy val globalTempViewManager: GlobalTempViewManager = { // System preserved database should not exists in metastore. However it's hard to guarantee it // for every session, because case-sensitivity differs. Here we always lowercase it to make our // life easier. - val globalTempDB = sparkContext.conf.get(GLOBAL_TEMP_DATABASE).toLowerCase + val globalTempDB = sparkContext.conf.get(GLOBAL_TEMP_DATABASE).toLowerCase(Locale.ROOT) if (externalCatalog.databaseExists(globalTempDB)) { throw new SparkException( s"$globalTempDB is a system preserved database, please rename your existing database " + @@ -145,7 +156,13 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { } } -object SharedState { +object SharedState extends Logging { + try { + URL.setURLStreamHandlerFactory(new FsUrlStreamHandlerFactory()) + } catch { + case e: Error => + logWarning("URL.setURLStreamHandlerFactory failed to set FsUrlStreamHandlerFactory") + } private val HIVE_EXTERNAL_CATALOG_CLASS_NAME = "org.apache.spark.sql.hive.HiveExternalCatalog" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index f541996b651e9..20e634c06b610 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -43,10 +43,6 @@ private case object OracleDialect extends JdbcDialect { // Not sure if there is a more robust way to identify the field as a float (or other // numeric types that do not specify a scale. case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) - case 1 => Option(BooleanType) - case 3 | 5 | 10 => Option(IntegerType) - case 19 if scale == 0L => Option(LongType) - case 19 if scale == 4L => Option(FloatType) case _ => None } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index ff8b15b3ff3ff..86eeb2f7dd419 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -163,16 +163,13 @@ trait StreamSinkProvider { @InterfaceStability.Stable trait CreatableRelationProvider { /** - * Save the DataFrame to the destination and return a relation with the given parameters based on - * the contents of the given DataFrame. The mode specifies the expected behavior of createRelation - * when data already exists. - * Right now, there are three modes, Append, Overwrite, and ErrorIfExists. - * Append mode means that when saving a DataFrame to a data source, if data already exists, - * contents of the DataFrame are expected to be appended to existing data. - * Overwrite mode means that when saving a DataFrame to a data source, if data already exists, - * existing data is expected to be overwritten by the contents of the DataFrame. - * ErrorIfExists mode means that when saving a DataFrame to a data source, - * if data already exists, an exception is expected to be thrown. + * Saves a DataFrame to a destination (using data source-specific parameters) + * + * @param sqlContext SQLContext + * @param mode specifies what happens when the destination already exists + * @param parameters data source-specific parameters + * @param data DataFrame to save (i.e. the rows after executing the query) + * @return Relation with a known schema * * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 746b2a94f102d..7e8e6394b4862 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.execution.command.DDLUtils @@ -35,7 +35,6 @@ import org.apache.spark.sql.types.StructType * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging { /** @@ -164,7 +163,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * Loads a JSON file stream and returns the results as a `DataFrame`. * * JSON Lines (newline-delimited JSON) is supported by - * default. For JSON (one record per file), set the `wholeFile` option to true. + * default. For JSON (one record per file), set the `multiLine` option to true. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -206,7 +205,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSXXX`): sets the string that * indicates a timestamp format. Custom date formats follow the formats at * `java.text.SimpleDateFormat`. This applies to timestamp type.
  • - *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines, + *
  • `multiLine` (default `false`): parse one record, which may span multiple lines, * per file
  • * * @@ -277,7 +276,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo *
  • `columnNameOfCorruptRecord` (default is the value specified in * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
  • - *
  • `wholeFile` (default `false`): parse one record, which may span multiple lines.
  • + *
  • `multiLine` (default `false`): parse one record, which may span multiple lines.
  • * * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 0d2611f9bbcce..14e7df672cc58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.DDLUtils @@ -29,13 +29,11 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink} /** - * :: Experimental :: * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, * key-value stores, etc). Use `Dataset.writeStream` to access this. * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index c659ac7fcf3d9..04a956b70b022 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -212,7 +212,7 @@ trait GroupState[S] extends LogicalGroupState[S] { @throws[IllegalArgumentException]("when updating with null") def update(newState: S): Unit - /** Remove this state. Note that this resets any timeout configuration as well. */ + /** Remove this state. */ def remove(): Unit /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala index 9ba1fc01cbd30..a033575d3d38f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/ProcessingTime.scala @@ -23,11 +23,10 @@ import scala.concurrent.duration.Duration import org.apache.commons.lang3.StringUtils -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.unsafe.types.CalendarInterval /** - * :: Experimental :: * A trigger that runs a query periodically based on the processing time. If `interval` is 0, * the query will run as fast as possible. * @@ -49,7 +48,6 @@ import org.apache.spark.unsafe.types.CalendarInterval * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving @deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") case class ProcessingTime(intervalMs: Long) extends Trigger { @@ -57,12 +55,10 @@ case class ProcessingTime(intervalMs: Long) extends Trigger { } /** - * :: Experimental :: * Used to create [[ProcessingTime]] triggers for [[StreamingQuery]]s. * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving @deprecated("use Trigger.ProcessingTime(intervalMs)", "2.2.0") object ProcessingTime { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index 12a1bb1db5779..f2dfbe42260d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -19,16 +19,14 @@ package org.apache.spark.sql.streaming import java.util.UUID -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.SparkSession /** - * :: Experimental :: * A handle to a query that is executing continuously in the background as new data arrives. * All these methods are thread-safe. * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving trait StreamingQuery { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala index 234a1166a1953..03aeb14de502a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.streaming -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability /** - * :: Experimental :: * Exception that stopped a [[StreamingQuery]]. Use `cause` get the actual exception * that caused the failure. * @param message Message of this exception @@ -29,7 +28,6 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability} * @param endOffset Ending offset in json of the range of data in exception occurred * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving class StreamingQueryException private[sql]( private val queryDebugString: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index c376913516ef7..6aa82b89ede81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -19,17 +19,15 @@ package org.apache.spark.sql.streaming import java.util.UUID -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.scheduler.SparkListenerEvent /** - * :: Experimental :: * Interface for listening to events related to [[StreamingQuery StreamingQueries]]. * @note The methods are not thread-safe as they may be called from different threads. * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving abstract class StreamingQueryListener { @@ -66,32 +64,26 @@ abstract class StreamingQueryListener { /** - * :: Experimental :: * Companion object of [[StreamingQueryListener]] that defines the listener events. * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving object StreamingQueryListener { /** - * :: Experimental :: * Base type of [[StreamingQueryListener]] events * @since 2.0.0 */ - @Experimental @InterfaceStability.Evolving trait Event extends SparkListenerEvent /** - * :: Experimental :: * Event representing the start of a query * @param id An unique query id that persists across restarts. See `StreamingQuery.id()`. * @param runId A query id that is unique for every start/restart. See `StreamingQuery.runId()`. * @param name User-specified name of the query, null if not specified. * @since 2.1.0 */ - @Experimental @InterfaceStability.Evolving class QueryStartedEvent private[sql]( val id: UUID, @@ -99,17 +91,14 @@ object StreamingQueryListener { val name: String) extends Event /** - * :: Experimental :: * Event representing any progress updates in a query. * @param progress The query progress updates. * @since 2.1.0 */ - @Experimental @InterfaceStability.Evolving class QueryProgressEvent private[sql](val progress: StreamingQueryProgress) extends Event /** - * :: Experimental :: * Event representing that termination of a query. * * @param id An unique query id that persists across restarts. See `StreamingQuery.id()`. @@ -118,7 +107,6 @@ object StreamingQueryListener { * with an exception. Otherwise, it will be `None`. * @since 2.1.0 */ - @Experimental @InterfaceStability.Evolving class QueryTerminatedEvent private[sql]( val id: UUID, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 7810d9f6e9642..002c45413b4c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker @@ -34,12 +34,10 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{Clock, SystemClock, Utils} /** - * :: Experimental :: - * A class to manage all the [[StreamingQuery]] active on a `SparkSession`. + * A class to manage all the [[StreamingQuery]] active in a `SparkSession`. * * @since 2.0.0 */ -@Experimental @InterfaceStability.Evolving class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala index 687b1267825fe..a0c9bcc8929eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -22,10 +22,9 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability /** - * :: Experimental :: * Reports information about the instantaneous status of a streaming query. * * @param message A human readable description of what the stream is currently doing. @@ -35,7 +34,6 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability} * * @since 2.1.0 */ -@Experimental @InterfaceStability.Evolving class StreamingQueryStatus protected[sql]( val message: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index 35fe6b8605fad..5171852c48b9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -29,13 +29,11 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.annotation.InterfaceStability /** - * :: Experimental :: * Information about updates made to stateful operators in a [[StreamingQuery]] during a trigger. */ -@Experimental @InterfaceStability.Evolving class StateOperatorProgress private[sql]( val numRowsTotal: Long, @@ -51,10 +49,11 @@ class StateOperatorProgress private[sql]( ("numRowsTotal" -> JInt(numRowsTotal)) ~ ("numRowsUpdated" -> JInt(numRowsUpdated)) } + + override def toString: String = prettyJson } /** - * :: Experimental :: * Information about progress made in the execution of a [[StreamingQuery]] during * a trigger. Each event relates to processing done for a single trigger of the streaming * query. Events are emitted even when no new data is available to be processed. @@ -80,7 +79,6 @@ class StateOperatorProgress private[sql]( * @param sources detailed statistics on data being read from each of the streaming sources. * @since 2.1.0 */ -@Experimental @InterfaceStability.Evolving class StreamingQueryProgress private[sql]( val id: UUID, @@ -139,7 +137,6 @@ class StreamingQueryProgress private[sql]( } /** - * :: Experimental :: * Information about progress made for a source in the execution of a [[StreamingQuery]] * during a trigger. See [[StreamingQueryProgress]] for more information. * @@ -152,7 +149,6 @@ class StreamingQueryProgress private[sql]( * Spark. * @since 2.1.0 */ -@Experimental @InterfaceStability.Evolving class SourceProgress protected[sql]( val description: String, @@ -191,14 +187,12 @@ class SourceProgress protected[sql]( } /** - * :: Experimental :: * Information about progress made for a sink in the execution of a [[StreamingQuery]] * during a trigger. See [[StreamingQueryProgress]] for more information. * * @param description Description of the source corresponding to this status. * @since 2.1.0 */ -@Experimental @InterfaceStability.Evolving class SinkProgress protected[sql]( val description: String) extends Serializable { diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index cfd7889b4ac2c..c6973bf41d34b 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,3 +1,7 @@ org.apache.spark.sql.sources.FakeSourceOne org.apache.spark.sql.sources.FakeSourceTwo org.apache.spark.sql.sources.FakeSourceThree +org.apache.spark.sql.sources.FakeSourceFour +org.apache.fakesource.FakeExternalSourceOne +org.apache.fakesource.FakeExternalSourceTwo +org.apache.fakesource.FakeExternalSourceThree diff --git a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql index f62b10ca0037b..492a405d7ebbd 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql @@ -32,3 +32,27 @@ select 1 - 2; select 2 * 5; select 5 % 3; select pmod(-7, 3); + +-- math functions +select cot(1); +select cot(null); +select cot(0); +select cot(-1); + +-- ceil and ceiling +select ceiling(0); +select ceiling(1); +select ceil(1234567890123456); +select ceiling(1234567890123456); +select ceil(0.01); +select ceiling(-0.10); + +-- floor +select floor(0); +select floor(1); +select floor(1234567890123456); +select floor(0.01); +select floor(-0.10); + +-- comparison operator +select 1 > 0.00001 \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/inputs/cast.sql b/sql/core/src/test/resources/sql-tests/inputs/cast.sql index 5fae571945e41..629df59cff8b3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/cast.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/cast.sql @@ -40,4 +40,6 @@ SELECT CAST('-9223372036854775809' AS long); SELECT CAST('9223372036854775807' AS long); SELECT CAST('9223372036854775808' AS long); +DESC FUNCTION boolean; +DESC FUNCTION EXTENDED boolean; -- TODO: migrate all cast tests here. diff --git a/sql/core/src/test/resources/sql-tests/inputs/comparator.sql b/sql/core/src/test/resources/sql-tests/inputs/comparator.sql new file mode 100644 index 0000000000000..3e2447723e576 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/comparator.sql @@ -0,0 +1,3 @@ +-- binary type +select x'00' < x'0f'; +select x'00' < x'ff'; diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql index 6de4cf0d5afa1..91b966829f8fb 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/describe.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql @@ -1,4 +1,5 @@ CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet + OPTIONS (a '1', b '2') PARTITIONED BY (c, d) CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS COMMENT 'table_comment'; @@ -13,6 +14,8 @@ CREATE TEMPORARY VIEW temp_Data_Source_View CREATE VIEW v AS SELECT * FROM t; +ALTER TABLE t SET TBLPROPERTIES (e = '3'); + ALTER TABLE t ADD PARTITION (c='Us', d=1); DESCRIBE t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql index f8135389a9e5a..8aff4cb524199 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql @@ -54,4 +54,9 @@ SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(co ORDER BY GROUPING(course), GROUPING(year), course, year; SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course); SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course); -SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id; \ No newline at end of file +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id; + +-- Aliases in SELECT could be used in ROLLUP/CUBE/GROUPING SETS +SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2); +SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b); +SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql index 9c8d851e36e9b..928f766b4add2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -49,7 +49,10 @@ select a, count(a) from (select 1 as a) tmp group by 1 order by 1; -- group by ordinal followed by having select count(a), a from (select 1 as a) tmp group by 2 having a > 0; --- turn of group by ordinal +-- mixed cases: group-by ordinals and aliases +select a, a AS k, count(b) from data group by k, 1; + +-- turn off group by ordinal set spark.sql.groupByOrdinal=false; -- can now group by negative literal diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 4d0ed43153004..1e1384549a410 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -35,3 +35,28 @@ FROM testData; -- Aggregate with foldable input and multiple distinct groups. SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; + +-- Aliases in SELECT could be used in GROUP BY +SELECT a AS k, COUNT(b) FROM testData GROUP BY k; +SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1; + +-- Aggregate functions cannot be used in GROUP BY +SELECT COUNT(b) AS k FROM testData GROUP BY k; + +-- Test data. +CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES +(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v); +SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a; + +-- turn off group by aliases +set spark.sql.groupByAliases=false; + +-- Check analysis exceptions +SELECT a AS k, COUNT(b) FROM testData GROUP BY k; + +-- Aggregate with empty input and non-empty GroupBy expressions. +SELECT a, COUNT(1) FROM testData WHERE false GROUP BY a; + +-- Aggregate with empty input and empty GroupBy expressions. +SELECT COUNT(1) FROM testData WHERE false; +SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t; diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql index 364c022d959dc..868a911e787f6 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/having.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql @@ -13,3 +13,6 @@ SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2; -- SPARK-11032: resolve having correctly SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0); + +-- SPARK-20329: make sure we handle timezones correctly +SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql index 2b5b692d29ef4..f1461032065ad 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql @@ -23,3 +23,7 @@ SELECT float(1), double(1), decimal(1); SELECT date("2014-04-04"), timestamp(date("2014-04-04")); -- error handling: only one argument SELECT string(1, 2); + +-- SPARK-21555: RuntimeReplaceable used in group by +CREATE TEMPORARY VIEW tempView1 AS VALUES (1, NAMED_STRUCT('col1', 'gamma', 'col2', 'delta')) AS T(id, st); +SELECT nvl(st.col1, "value"), count(*) FROM from tempView1 GROUP BY nvl(st.col1, "value"); diff --git a/sql/core/src/test/resources/sql-tests/inputs/struct.sql b/sql/core/src/test/resources/sql-tests/inputs/struct.sql index e56344dc4de80..93a1238ab18c2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/struct.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/struct.sql @@ -18,3 +18,10 @@ SELECT ID, STRUCT(ST.*,CAST(ID AS STRING) AS E) NST FROM tbl_x; -- Prepend a column to a struct SELECT ID, STRUCT(CAST(ID AS STRING) AS AA, ST.*) NST FROM tbl_x; + +-- Select a column from a struct +SELECT ID, STRUCT(ST.*).C NST FROM tbl_x; +SELECT ID, STRUCT(ST.C, ST.D).D NST FROM tbl_x; + +-- Select an alias from a struct +SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql new file mode 100644 index 0000000000000..c800fc3d49891 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql @@ -0,0 +1,69 @@ +-- Test data. +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(null, "a"), (1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b"), (null, null), (3, null) +AS testData(val, cate); + +-- RowsBetween +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW) FROM testData +ORDER BY cate, val; +SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val +ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) FROM testData ORDER BY cate, val; + +-- RangeBetween +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val RANGE 1 PRECEDING) FROM testData +ORDER BY cate, val; +SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; + +-- RangeBetween with reverse OrderBy +SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; + +-- Window functions +SELECT val, cate, +max(val) OVER w AS max, +min(val) OVER w AS min, +min(val) OVER w AS min, +count(val) OVER w AS count, +sum(val) OVER w AS sum, +avg(val) OVER w AS avg, +stddev(val) OVER w AS stddev, +first_value(val) OVER w AS first_value, +first_value(val, true) OVER w AS first_value_ignore_null, +first_value(val, false) OVER w AS first_value_contain_null, +last_value(val) OVER w AS last_value, +last_value(val, true) OVER w AS last_value_ignore_null, +last_value(val, false) OVER w AS last_value_contain_null, +rank() OVER w AS rank, +dense_rank() OVER w AS dense_rank, +cume_dist() OVER w AS cume_dist, +percent_rank() OVER w AS percent_rank, +ntile(2) OVER w AS ntile, +row_number() OVER w AS row_number, +var_pop(val) OVER w AS var_pop, +var_samp(val) OVER w AS var_samp, +approx_count_distinct(val) OVER w AS approx_count_distinct +FROM testData +WINDOW w AS (PARTITION BY cate ORDER BY val) +ORDER BY cate, val; + +-- Null inputs +SELECT val, cate, avg(null) OVER(PARTITION BY cate ORDER BY val) FROM testData ORDER BY cate, val; + +-- OrderBy not specified +SELECT val, cate, row_number() OVER(PARTITION BY cate) FROM testData ORDER BY cate, val; + +-- Over clause is empty +SELECT val, cate, sum(val) OVER(), avg(val) OVER() FROM testData ORDER BY cate, val; + +-- first_value()/last_value() over () +SELECT val, cate, +first_value(false) OVER w AS first_value, +first_value(true, true) OVER w AS first_value_ignore_null, +first_value(false, false) OVER w AS first_value_contain_null, +last_value(false) OVER w AS last_value, +last_value(true, true) OVER w AS last_value_ignore_null, +last_value(false, false) OVER w AS last_value_contain_null +FROM testData +WINDOW w AS () +ORDER BY cate, val; diff --git a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out index ce42c016a7100..3811cd2c30986 100644 --- a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 28 +-- Number of queries: 44 -- !query 0 @@ -224,3 +224,135 @@ select pmod(-7, 3) struct -- !query 27 output 2 + + +-- !query 28 +select cot(1) +-- !query 28 schema +struct<> +-- !query 28 output +org.apache.spark.sql.AnalysisException +Undefined function: 'cot'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7 + + +-- !query 29 +select cot(null) +-- !query 29 schema +struct<> +-- !query 29 output +org.apache.spark.sql.AnalysisException +Undefined function: 'cot'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7 + + +-- !query 30 +select cot(0) +-- !query 30 schema +struct<> +-- !query 30 output +org.apache.spark.sql.AnalysisException +Undefined function: 'cot'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7 + + +-- !query 31 +select cot(-1) +-- !query 31 schema +struct<> +-- !query 31 output +org.apache.spark.sql.AnalysisException +Undefined function: 'cot'. This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 7 + + +-- !query 32 +select ceiling(0) +-- !query 32 schema +struct +-- !query 32 output +0 + + +-- !query 33 +select ceiling(1) +-- !query 33 schema +struct +-- !query 33 output +1 + + +-- !query 34 +select ceil(1234567890123456) +-- !query 34 schema +struct +-- !query 34 output +1234567890123456 + + +-- !query 35 +select ceiling(1234567890123456) +-- !query 35 schema +struct +-- !query 35 output +1234567890123456 + + +-- !query 36 +select ceil(0.01) +-- !query 36 schema +struct +-- !query 36 output +1 + + +-- !query 37 +select ceiling(-0.10) +-- !query 37 schema +struct +-- !query 37 output +0 + + +-- !query 38 +select floor(0) +-- !query 38 schema +struct +-- !query 38 output +0 + + +-- !query 39 +select floor(1) +-- !query 39 schema +struct +-- !query 39 output +1 + + +-- !query 40 +select floor(1234567890123456) +-- !query 40 schema +struct +-- !query 40 output +1234567890123456 + + +-- !query 41 +select floor(0.01) +-- !query 41 schema +struct +-- !query 41 output +0 + + +-- !query 42 +select floor(-0.10) +-- !query 42 schema +struct +-- !query 42 output +-1 + + +-- !query 43 +select 1 > 0.00001 +-- !query 43 schema +struct<(CAST(1 AS BIGINT) > 0):boolean> +-- !query 43 output +true diff --git a/sql/core/src/test/resources/sql-tests/results/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/cast.sql.out index bfa29d7d2d597..4e6353b1f332c 100644 --- a/sql/core/src/test/resources/sql-tests/results/cast.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cast.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 22 +-- Number of queries: 24 -- !query 0 @@ -176,3 +176,24 @@ SELECT CAST('9223372036854775808' AS long) struct -- !query 21 output NULL + + +-- !query 22 +DESC FUNCTION boolean +-- !query 22 schema +struct +-- !query 22 output +Class: org.apache.spark.sql.catalyst.expressions.Cast +Function: boolean +Usage: boolean(expr) - Casts the value `expr` to the target data type `boolean`. + + +-- !query 23 +DESC FUNCTION EXTENDED boolean +-- !query 23 schema +struct +-- !query 23 output +Class: org.apache.spark.sql.catalyst.expressions.Cast +Extended Usage:N/A. +Function: boolean +Usage: boolean(expr) - Casts the value `expr` to the target data type `boolean`. diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out index 678a3f0f0a3c6..ba8bc936f0c79 100644 --- a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -15,7 +15,6 @@ DESC test_change -- !query 1 schema struct -- !query 1 output -# col_name data_type comment a int b string c int @@ -35,7 +34,6 @@ DESC test_change -- !query 3 schema struct -- !query 3 output -# col_name data_type comment a int b string c int @@ -55,7 +53,6 @@ DESC test_change -- !query 5 schema struct -- !query 5 output -# col_name data_type comment a int b string c int @@ -94,7 +91,6 @@ DESC test_change -- !query 8 schema struct -- !query 8 output -# col_name data_type comment a int b string c int @@ -129,7 +125,6 @@ DESC test_change -- !query 12 schema struct -- !query 12 output -# col_name data_type comment a int this is column a b string #*02?` c int @@ -148,7 +143,6 @@ DESC test_change -- !query 14 schema struct -- !query 14 output -# col_name data_type comment a int this is column a b string #*02?` c int @@ -168,7 +162,6 @@ DESC test_change -- !query 16 schema struct -- !query 16 output -# col_name data_type comment a int this is column a b string #*02?` c int @@ -193,7 +186,6 @@ DESC test_change -- !query 18 schema struct -- !query 18 output -# col_name data_type comment a int this is column a b string #*02?` c int @@ -237,7 +229,6 @@ DESC test_change -- !query 23 schema struct -- !query 23 output -# col_name data_type comment a int this is column A b string #*02?` c int diff --git a/sql/core/src/test/resources/sql-tests/results/comparator.sql.out b/sql/core/src/test/resources/sql-tests/results/comparator.sql.out new file mode 100644 index 0000000000000..afc7b5448b7b6 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/comparator.sql.out @@ -0,0 +1,18 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 2 + + +-- !query 0 +select x'00' < x'0f' +-- !query 0 schema +struct<(X'00' < X'0F'):boolean> +-- !query 0 output +true + + +-- !query 1 +select x'00' < x'ff' +-- !query 1 schema +struct<(X'00' < X'FF'):boolean> +-- !query 1 output +true diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out index de10b29f3c65b..ab9f2783f06bb 100644 --- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -1,9 +1,10 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 31 +-- Number of queries: 32 -- !query 0 CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet + OPTIONS (a '1', b '2') PARTITIONED BY (c, d) CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS COMMENT 'table_comment' -- !query 0 schema @@ -42,7 +43,7 @@ struct<> -- !query 4 -ALTER TABLE t ADD PARTITION (c='Us', d=1) +ALTER TABLE t SET TBLPROPERTIES (e = '3') -- !query 4 schema struct<> -- !query 4 output @@ -50,11 +51,18 @@ struct<> -- !query 5 -DESCRIBE t +ALTER TABLE t ADD PARTITION (c='Us', d=1) -- !query 5 schema -struct +struct<> -- !query 5 output -# col_name data_type comment + + + +-- !query 6 +DESCRIBE t +-- !query 6 schema +struct +-- !query 6 output a string b int c string @@ -65,12 +73,11 @@ c string d string --- !query 6 +-- !query 7 DESC default.t --- !query 6 schema +-- !query 7 schema struct --- !query 6 output -# col_name data_type comment +-- !query 7 output a string b int c string @@ -81,12 +88,11 @@ c string d string --- !query 7 +-- !query 8 DESC TABLE t --- !query 7 schema +-- !query 8 schema struct --- !query 7 output -# col_name data_type comment +-- !query 8 output a string b int c string @@ -97,12 +103,11 @@ c string d string --- !query 8 +-- !query 9 DESC FORMATTED t --- !query 8 schema +-- !query 9 schema struct --- !query 8 output -# col_name data_type comment +-- !query 9 output a string b int c string @@ -123,16 +128,17 @@ Num Buckets 2 Bucket Columns [`a`] Sort Columns [`b`] Comment table_comment +Table Properties [e=3] Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] Partition Provider Catalog --- !query 9 +-- !query 10 DESC EXTENDED t --- !query 9 schema +-- !query 10 schema struct --- !query 9 output -# col_name data_type comment +-- !query 10 output a string b int c string @@ -153,16 +159,17 @@ Num Buckets 2 Bucket Columns [`a`] Sort Columns [`b`] Comment table_comment +Table Properties [e=3] Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] Partition Provider Catalog --- !query 10 +-- !query 11 DESC t PARTITION (c='Us', d=1) --- !query 10 schema +-- !query 11 schema struct --- !query 10 output -# col_name data_type comment +-- !query 11 output a string b int c string @@ -173,12 +180,11 @@ c string d string --- !query 11 +-- !query 12 DESC EXTENDED t PARTITION (c='Us', d=1) --- !query 11 schema +-- !query 12 schema struct --- !query 11 output -# col_name data_type comment +-- !query 12 output a string b int c string @@ -193,20 +199,21 @@ Database default Table t Partition Values [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 +Storage Properties [a=1, b=2] # Storage Information Num Buckets 2 Bucket Columns [`a`] Sort Columns [`b`] -Location [not included in comparison]sql/core/spark-warehouse/t +Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] --- !query 12 +-- !query 13 DESC FORMATTED t PARTITION (c='Us', d=1) --- !query 12 schema +-- !query 13 schema struct --- !query 12 output -# col_name data_type comment +-- !query 13 output a string b int c string @@ -221,39 +228,41 @@ Database default Table t Partition Values [c=Us, d=1] Location [not included in comparison]sql/core/spark-warehouse/t/c=Us/d=1 +Storage Properties [a=1, b=2] # Storage Information Num Buckets 2 Bucket Columns [`a`] Sort Columns [`b`] -Location [not included in comparison]sql/core/spark-warehouse/t +Location [not included in comparison]sql/core/spark-warehouse/t +Storage Properties [a=1, b=2] --- !query 13 +-- !query 14 DESC t PARTITION (c='Us', d=2) --- !query 13 schema +-- !query 14 schema struct<> --- !query 13 output +-- !query 14 output org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException Partition not found in table 't' database 'default': c -> Us d -> 2; --- !query 14 +-- !query 15 DESC t PARTITION (c='Us') --- !query 14 schema +-- !query 15 schema struct<> --- !query 14 output +-- !query 15 output org.apache.spark.sql.AnalysisException Partition spec is invalid. The spec (c) must match the partition spec (c, d) defined in table '`default`.`t`'; --- !query 15 +-- !query 16 DESC t PARTITION (c='Us', d) --- !query 15 schema +-- !query 16 schema struct<> --- !query 15 output +-- !query 16 output org.apache.spark.sql.catalyst.parser.ParseException PARTITION specification is incomplete: `d`(line 1, pos 0) @@ -263,24 +272,11 @@ DESC t PARTITION (c='Us', d) ^^^ --- !query 16 -DESC temp_v --- !query 16 schema -struct --- !query 16 output -# col_name data_type comment -a string -b int -c string -d string - - -- !query 17 -DESC TABLE temp_v +DESC temp_v -- !query 17 schema struct -- !query 17 output -# col_name data_type comment a string b int c string @@ -288,11 +284,10 @@ d string -- !query 18 -DESC FORMATTED temp_v +DESC TABLE temp_v -- !query 18 schema struct -- !query 18 output -# col_name data_type comment a string b int c string @@ -300,11 +295,10 @@ d string -- !query 19 -DESC EXTENDED temp_v +DESC FORMATTED temp_v -- !query 19 schema struct -- !query 19 output -# col_name data_type comment a string b int c string @@ -312,11 +306,21 @@ d string -- !query 20 -DESC temp_Data_Source_View +DESC EXTENDED temp_v -- !query 20 schema struct -- !query 20 output -# col_name data_type comment +a string +b int +c string +d string + + +-- !query 21 +DESC temp_Data_Source_View +-- !query 21 schema +struct +-- !query 21 output intType int test comment test1 stringType string dateType date @@ -335,45 +339,42 @@ arrayType array structType struct --- !query 21 +-- !query 22 DESC temp_v PARTITION (c='Us', d=1) --- !query 21 schema +-- !query 22 schema struct<> --- !query 21 output +-- !query 22 output org.apache.spark.sql.AnalysisException DESC PARTITION is not allowed on a temporary view: temp_v; --- !query 22 +-- !query 23 DESC v --- !query 22 schema +-- !query 23 schema struct --- !query 22 output -# col_name data_type comment +-- !query 23 output a string b int c string d string --- !query 23 +-- !query 24 DESC TABLE v --- !query 23 schema +-- !query 24 schema struct --- !query 23 output -# col_name data_type comment +-- !query 24 output a string b int c string d string --- !query 24 +-- !query 25 DESC FORMATTED v --- !query 24 schema +-- !query 25 schema struct --- !query 24 output -# col_name data_type comment +-- !query 25 output a string b int c string @@ -388,15 +389,14 @@ Type VIEW View Text SELECT * FROM t View Default Database default View Query Output Columns [a, b, c, d] -Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] +Table Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] --- !query 25 +-- !query 26 DESC EXTENDED v --- !query 25 schema +-- !query 26 schema struct --- !query 25 output -# col_name data_type comment +-- !query 26 output a string b int c string @@ -411,28 +411,20 @@ Type VIEW View Text SELECT * FROM t View Default Database default View Query Output Columns [a, b, c, d] -Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] - - --- !query 26 -DESC v PARTITION (c='Us', d=1) --- !query 26 schema -struct<> --- !query 26 output -org.apache.spark.sql.AnalysisException -DESC PARTITION is not allowed on a view: v; +Table Properties [view.query.out.col.3=d, view.query.out.col.0=a, view.query.out.numCols=4, view.default.database=default, view.query.out.col.1=b, view.query.out.col.2=c] -- !query 27 -DROP TABLE t +DESC v PARTITION (c='Us', d=1) -- !query 27 schema struct<> -- !query 27 output - +org.apache.spark.sql.AnalysisException +DESC PARTITION is not allowed on a view: v; -- !query 28 -DROP VIEW temp_v +DROP TABLE t -- !query 28 schema struct<> -- !query 28 output @@ -440,7 +432,7 @@ struct<> -- !query 29 -DROP VIEW temp_Data_Source_View +DROP VIEW temp_v -- !query 29 schema struct<> -- !query 29 output @@ -448,8 +440,16 @@ struct<> -- !query 30 -DROP VIEW v +DROP VIEW temp_Data_Source_View -- !query 30 schema struct<> -- !query 30 output + + +-- !query 31 +DROP VIEW v +-- !query 31 schema +struct<> +-- !query 31 output + diff --git a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out index 825e8f5488c8b..ce7a16a4d0c81 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 26 +-- Number of queries: 29 -- !query 0 @@ -328,3 +328,50 @@ struct<> -- !query 25 output org.apache.spark.sql.AnalysisException grouping__id is deprecated; use grouping_id() instead; + + +-- !query 26 +SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2) +-- !query 26 schema +struct +-- !query 26 output +2 1 0 +2 NULL 0 +3 1 1 +3 2 -1 +3 NULL 0 +4 1 2 +4 2 0 +4 NULL 2 +5 2 1 +5 NULL 1 +NULL 1 3 +NULL 2 0 +NULL NULL 3 + + +-- !query 27 +SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b) +-- !query 27 schema +struct +-- !query 27 output +2 1 0 +2 NULL 0 +3 1 1 +3 2 -1 +3 NULL 0 +4 1 2 +4 2 0 +4 NULL 2 +5 2 1 +5 NULL 1 +NULL NULL 3 + + +-- !query 28 +SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k) +-- !query 28 schema +struct<(a + b):int,k:int,sum((a - b)):bigint> +-- !query 28 output +NULL 1 3 +NULL 2 0 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index c0930bbde69a4..9ecbe19078dd6 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 19 +-- Number of queries: 20 -- !query 0 @@ -122,7 +122,7 @@ select a, b, sum(b) from data group by 3 struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 39 +aggregate functions are not allowed in GROUP BY, but found sum(CAST(data.`b` AS BIGINT)); -- !query 12 @@ -131,7 +131,7 @@ select a, b, sum(b) + 2 from data group by 3 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 43 +aggregate functions are not allowed in GROUP BY, but found (sum(CAST(data.`b` AS BIGINT)) + CAST(2 AS BIGINT)); -- !query 13 @@ -173,16 +173,26 @@ struct -- !query 17 -set spark.sql.groupByOrdinal=false +select a, a AS k, count(b) from data group by k, 1 -- !query 17 schema -struct +struct -- !query 17 output -spark.sql.groupByOrdinal false +1 1 2 +2 2 2 +3 3 2 -- !query 18 -select sum(b) from data group by -1 +set spark.sql.groupByOrdinal=false -- !query 18 schema -struct +struct -- !query 18 output +spark.sql.groupByOrdinal false + + +-- !query 19 +select sum(b) from data group by -1 +-- !query 19 schema +struct +-- !query 19 output 9 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 4b87d5161fc0e..42e82308ee1f0 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 25 -- !query 0 @@ -139,3 +139,91 @@ SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS struct -- !query 14 output 1 1 + + +-- !query 15 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k +-- !query 15 schema +struct +-- !query 15 output +1 2 +2 2 +3 2 +NULL 1 + + +-- !query 16 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1 +-- !query 16 schema +struct +-- !query 16 output +2 2 +3 2 + + +-- !query 17 +SELECT COUNT(b) AS k FROM testData GROUP BY k +-- !query 17 schema +struct<> +-- !query 17 output +org.apache.spark.sql.AnalysisException +aggregate functions are not allowed in GROUP BY, but found count(testdata.`b`); + + +-- !query 18 +CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES +(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v) +-- !query 18 schema +struct<> +-- !query 18 output + + + +-- !query 19 +SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a +-- !query 19 schema +struct<> +-- !query 19 output +org.apache.spark.sql.AnalysisException +expression 'testdatahassamenamewithalias.`k`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; + + +-- !query 20 +set spark.sql.groupByAliases=false +-- !query 20 schema +struct +-- !query 20 output +spark.sql.groupByAliases false + + +-- !query 21 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '`k`' given input columns: [a, b]; line 1 pos 47 + + +-- !query 22 +SELECT a, COUNT(1) FROM testData WHERE false GROUP BY a +-- !query 22 schema +struct +-- !query 22 output + + + +-- !query 23 +SELECT COUNT(1) FROM testData WHERE false +-- !query 23 schema +struct +-- !query 23 output +0 + + +-- !query 24 +SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t +-- !query 24 schema +struct<1:int> +-- !query 24 output +1 diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out index e0923832673cb..d87ee5221647f 100644 --- a/sql/core/src/test/resources/sql-tests/results/having.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 5 -- !query 0 @@ -38,3 +38,12 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0) struct -- !query 3 output 1 + + +-- !query 4 +SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1 +-- !query 4 schema +struct<(a + CAST(b AS BIGINT)):bigint> +-- !query 4 output +3 +7 diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 315e1730ce7df..fedabaee2237f 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -141,7 +141,7 @@ struct<> -- !query 13 output org.apache.spark.sql.AnalysisException -DataType invalidtype() is not supported.(line 1, pos 2) +DataType invalidtype is not supported.(line 1, pos 2) == SQL == a InvalidType diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out index 9f0b95994be53..e035505f15d28 100644 --- a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 13 +-- Number of queries: 15 -- !query 0 @@ -88,7 +88,7 @@ Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nul == Physical Plan == *Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, x AS nvl2(`id`, 'x', 'y')#x] -+- *Range (0, 2, step=1, splits=None) ++- *Range (0, 2, step=1, splits=2) -- !query 9 @@ -122,3 +122,19 @@ struct<> -- !query 12 output org.apache.spark.sql.AnalysisException Function string accepts only one argument; line 1 pos 7 + + +-- !query 13 +CREATE TEMPORARY VIEW tempView1 AS VALUES (1, NAMED_STRUCT('col1', 'gamma', 'col2', 'delta')) AS T(id, st) +-- !query 13 schema +struct<> +-- !query 13 output + + + +-- !query 14 +SELECT nvl(st.col1, "value"), count(*) FROM from tempView1 GROUP BY nvl(st.col1, "value") +-- !query 14 schema +struct +-- !query 14 output +gamma 1 diff --git a/sql/core/src/test/resources/sql-tests/results/struct.sql.out b/sql/core/src/test/resources/sql-tests/results/struct.sql.out index 3e32f46195464..1da33bc736f0b 100644 --- a/sql/core/src/test/resources/sql-tests/results/struct.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/struct.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 9 -- !query 0 @@ -58,3 +58,33 @@ struct> 1 {"AA":"1","C":"gamma","D":"delta"} 2 {"AA":"2","C":"epsilon","D":"eta"} 3 {"AA":"3","C":"theta","D":"iota"} + + +-- !query 6 +SELECT ID, STRUCT(ST.*).C NST FROM tbl_x +-- !query 6 schema +struct +-- !query 6 output +1 gamma +2 epsilon +3 theta + + +-- !query 7 +SELECT ID, STRUCT(ST.C, ST.D).D NST FROM tbl_x +-- !query 7 schema +struct +-- !query 7 output +1 delta +2 eta +3 iota + + +-- !query 8 +SELECT ID, STRUCT(ST.C as STC, ST.D as STD).STD FROM tbl_x +-- !query 8 schema +struct +-- !query 8 output +1 delta +2 eta +3 iota diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index acd4ecf14617e..e2ee970d35f60 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -102,4 +102,4 @@ EXPLAIN select * from RaNgE(2) struct -- !query 8 output == Physical Plan == -*Range (0, 2, step=1, splits=None) +*Range (0, 2, step=1, splits=2) diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out new file mode 100644 index 0000000000000..aa5856138ed81 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -0,0 +1,204 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 11 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES +(null, "a"), (1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b"), (null, null), (3, null) +AS testData(val, cate) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW) FROM testData +ORDER BY cate, val +-- !query 1 schema +struct +-- !query 1 output +NULL NULL 0 +3 NULL 1 +NULL a 0 +1 a 1 +1 a 1 +2 a 1 +1 b 1 +2 b 1 +3 b 1 + + +-- !query 2 +SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val +ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 2 schema +struct +-- !query 2 output +NULL NULL 3 +3 NULL 3 +NULL a 1 +1 a 2 +1 a 4 +2 a 4 +1 b 3 +2 b 6 +3 b 6 + + +-- !query 3 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val RANGE 1 PRECEDING) FROM testData +ORDER BY cate, val +-- !query 3 schema +struct +-- !query 3 output +NULL NULL 0 +3 NULL 1 +NULL a 0 +1 a 2 +1 a 2 +2 a 3 +1 b 1 +2 b 2 +3 b 2 + + +-- !query 4 +SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 4 schema +struct +-- !query 4 output +NULL NULL NULL +3 NULL 3 +NULL a NULL +1 a 4 +1 a 4 +2 a 2 +1 b 3 +2 b 5 +3 b 3 + + +-- !query 5 +SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 5 schema +struct +-- !query 5 output +NULL NULL NULL +3 NULL 3 +NULL a NULL +1 a 2 +1 a 2 +2 a 4 +1 b 1 +2 b 3 +3 b 5 + + +-- !query 6 +SELECT val, cate, +max(val) OVER w AS max, +min(val) OVER w AS min, +min(val) OVER w AS min, +count(val) OVER w AS count, +sum(val) OVER w AS sum, +avg(val) OVER w AS avg, +stddev(val) OVER w AS stddev, +first_value(val) OVER w AS first_value, +first_value(val, true) OVER w AS first_value_ignore_null, +first_value(val, false) OVER w AS first_value_contain_null, +last_value(val) OVER w AS last_value, +last_value(val, true) OVER w AS last_value_ignore_null, +last_value(val, false) OVER w AS last_value_contain_null, +rank() OVER w AS rank, +dense_rank() OVER w AS dense_rank, +cume_dist() OVER w AS cume_dist, +percent_rank() OVER w AS percent_rank, +ntile(2) OVER w AS ntile, +row_number() OVER w AS row_number, +var_pop(val) OVER w AS var_pop, +var_samp(val) OVER w AS var_samp, +approx_count_distinct(val) OVER w AS approx_count_distinct +FROM testData +WINDOW w AS (PARTITION BY cate ORDER BY val) +ORDER BY cate, val +-- !query 6 schema +struct +-- !query 6 output +NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 +3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 +NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 +1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 +1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 +2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 +1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 +2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 +3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 + + +-- !query 7 +SELECT val, cate, avg(null) OVER(PARTITION BY cate ORDER BY val) FROM testData ORDER BY cate, val +-- !query 7 schema +struct +-- !query 7 output +NULL NULL NULL +3 NULL NULL +NULL a NULL +1 a NULL +1 a NULL +2 a NULL +1 b NULL +2 b NULL +3 b NULL + + +-- !query 8 +SELECT val, cate, row_number() OVER(PARTITION BY cate) FROM testData ORDER BY cate, val +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +Window function row_number() requires window to be ordered, please add ORDER BY clause. For example SELECT row_number()(value_expr) OVER (PARTITION BY window_partition ORDER BY window_ordering) from table; + + +-- !query 9 +SELECT val, cate, sum(val) OVER(), avg(val) OVER() FROM testData ORDER BY cate, val +-- !query 9 schema +struct +-- !query 9 output +NULL NULL 13 1.8571428571428572 +3 NULL 13 1.8571428571428572 +NULL a 13 1.8571428571428572 +1 a 13 1.8571428571428572 +1 a 13 1.8571428571428572 +2 a 13 1.8571428571428572 +1 b 13 1.8571428571428572 +2 b 13 1.8571428571428572 +3 b 13 1.8571428571428572 + + +-- !query 10 +SELECT val, cate, +first_value(false) OVER w AS first_value, +first_value(true, true) OVER w AS first_value_ignore_null, +first_value(false, false) OVER w AS first_value_contain_null, +last_value(false) OVER w AS last_value, +last_value(true, true) OVER w AS last_value_ignore_null, +last_value(false, false) OVER w AS last_value_contain_null +FROM testData +WINDOW w AS () +ORDER BY cate, val +-- !query 10 schema +struct +-- !query 10 output +NULL NULL false true false false true false +3 NULL false true false false true false +NULL a false true false false true false +1 a false true false false true false +1 a false true false false true false +2 a false true false false true false +1 b false true false false true false +2 b false true false false true false +3 b false true false false true false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala index 3e85d95523125..7e61a68025158 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter -class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { +import org.apache.spark.SparkConf - protected override def beforeAll(): Unit = { - sparkConf.set("spark.sql.codegen.fallback", "false") - sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") - super.beforeAll() - } +class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.codegen.fallback", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") // adding some checking after each test is run, assuring that the configs are not changed // in test code @@ -38,12 +37,9 @@ class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with Befo } class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { - - protected override def beforeAll(): Unit = { - sparkConf.set("spark.sql.codegen.fallback", "false") - sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") - super.beforeAll() - } + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.codegen.fallback", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") // adding some checking after each test is run, assuring that the configs are not changed // in test code @@ -55,15 +51,14 @@ class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeA } } -class TwoLevelAggregateHashMapWithVectorizedMapSuite extends DataFrameAggregateSuite with -BeforeAndAfter { +class TwoLevelAggregateHashMapWithVectorizedMapSuite + extends DataFrameAggregateSuite + with BeforeAndAfter { - protected override def beforeAll(): Unit = { - sparkConf.set("spark.sql.codegen.fallback", "false") - sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") - sparkConf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") - super.beforeAll() - } + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.codegen.fallback", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + .set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") // adding some checking after each test is run, assuring that the configs are not changed // in test code diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index e66fe97afad45..3ad526873f5d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -647,7 +647,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext withTable("t") { withTempPath { path => Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath) - sql(s"CREATE TABLE t USING parquet LOCATION '$path'") + sql(s"CREATE TABLE t USING parquet LOCATION '${path.toURI}'") spark.catalog.cacheTable("t") spark.table("t").select($"i").cache() checkAnswer(spark.table("t").select($"i"), Row(1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b0f398dab7455..bc708ca88d7e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -39,6 +39,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) } + private lazy val nullData = Seq( + (Some(1), Some(1)), (Some(1), Some(2)), (Some(1), None), (None, None)).toDF("a", "b") + test("column names with space") { val df = Seq((1, "a")).toDF("name with space", "name.with.dot") @@ -283,23 +286,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("<=>") { - checkAnswer( - testData2.filter($"a" === 1), - testData2.collect().toSeq.filter(r => r.getInt(0) == 1)) - - checkAnswer( - testData2.filter($"a" === $"b"), - testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1))) - } - - test("=!=") { - val nullData = spark.createDataFrame(sparkContext.parallelize( - Row(1, 1) :: - Row(1, 2) :: - Row(1, null) :: - Row(null, null) :: Nil), - StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType)))) - checkAnswer( nullData.filter($"b" <=> 1), Row(1, 1) :: Nil) @@ -321,7 +307,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { checkAnswer( nullData2.filter($"a" <=> null), Row(null) :: Nil) + } + test("=!=") { + checkAnswer( + nullData.filter($"b" =!= 1), + Row(1, 2) :: Nil) + + checkAnswer(nullData.filter($"b" =!= null), Nil) + + checkAnswer( + nullData.filter($"a" =!= $"b"), + Row(1, 2) :: Nil) } test(">") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e7079120bb7df..f50c0cfcd00ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -186,6 +186,22 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("SPARK-21980: References in grouping functions should be indexed with semanticEquals") { + checkAnswer( + courseSales.cube("course", "year") + .agg(grouping("CouRse"), grouping("year")), + Row("Java", 2012, 0, 0) :: + Row("Java", 2013, 0, 0) :: + Row("Java", null, 0, 1) :: + Row("dotNET", 2012, 0, 0) :: + Row("dotNET", 2013, 0, 0) :: + Row("dotNET", null, 0, 1) :: + Row(null, 2012, 1, 0) :: + Row(null, 2013, 1, 0) :: + Row(null, null, 1, 1) :: Nil + ) + } + test("rollup overlapping columns") { checkAnswer( testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"), @@ -538,4 +554,27 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Seq(Row(3, 0, 0.0, 1, 5.0), Row(2, 1, 4.0, 0, 0.0)) ) } + + test("aggregate function in GROUP BY") { + val e = intercept[AnalysisException] { + testData.groupBy(sum($"key")).count() + } + assert(e.message.contains("aggregate functions are not allowed in GROUP BY")) + } + + test("SPARK-21580 ints in aggregation expressions are taken as group-by ordinal.") { + checkAnswer( + testData2.groupBy(lit(3), lit(4)).agg(lit(6), lit(7), sum("b")), + Seq(Row(3, 4, 6, 7, 9))) + checkAnswer( + testData2.groupBy(lit(3), lit(4)).agg(lit(6), 'b, sum("b")), + Seq(Row(3, 4, 6, 1, 3), Row(3, 4, 6, 2, 6))) + + checkAnswer( + spark.sql("SELECT 3, 4, SUM(b) FROM testData2 GROUP BY 1, 2"), + Seq(Row(3, 4, 9))) + checkAnswer( + spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c, d"), + Seq(Row(3, 4, 9))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala new file mode 100644 index 0000000000000..60f6f23860ed9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.test.SharedSQLContext + +class DataFrameHintSuite extends PlanTest with SharedSQLContext { + import testImplicits._ + lazy val df = spark.range(10) + + private def check(df: Dataset[_], expected: LogicalPlan) = { + comparePlans( + df.queryExecution.logical, + expected + ) + } + + test("various hint parameters") { + check( + df.hint("hint1"), + UnresolvedHint("hint1", Seq(), + df.logicalPlan + ) + ) + + check( + df.hint("hint1", 1, "a"), + UnresolvedHint("hint1", Seq(1, "a"), df.logicalPlan) + ) + + check( + df.hint("hint1", 1, $"a"), + UnresolvedHint("hint1", Seq(1, $"a"), + df.logicalPlan + ) + ) + + check( + df.hint("hint1", Seq(1, 2, 3), Seq($"a", $"b", $"c")), + UnresolvedHint("hint1", Seq(Seq(1, 2, 3), Seq($"a", $"b", $"c")), + df.logicalPlan + ) + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 541ffb58e727f..aef0d7f3e425b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -151,7 +151,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) } - test("broadcast join hint") { + test("broadcast join hint using broadcast function") { val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") @@ -174,6 +174,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { } } + test("broadcast join hint using Dataset.hint") { + // make sure a giant join is not broadcastable + val plan1 = + spark.range(10e10.toLong) + .join(spark.range(10e10.toLong), "id") + .queryExecution.executedPlan + assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size == 0) + + // now with a hint it should be broadcasted + val plan2 = + spark.range(10e10.toLong) + .join(spark.range(10e10.toLong).hint("broadcast"), "id") + .queryExecution.executedPlan + assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size == 1) + } + test("join - outer join conversion") { val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") @@ -248,4 +264,14 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { val ab = a.join(b, Seq("a"), "fullouter") checkAnswer(ab.join(c, "a"), Row(3, null, 4, 1) :: Nil) } + + test("SPARK-17685: WholeStageCodegenExec throws IndexOutOfBoundsException") { + val df = Seq((1, 1, "1"), (2, 2, "3")).toDF("int", "int2", "str") + val df2 = Seq((1, 1, "1"), (2, 3, "5")).toDF("int", "int2", "str") + val limit = 1310721 + val innerJoin = df.limit(limit).join(df2.limit(limit), Seq("int", "int2"), "inner") + .agg(count($"int")) + checkAnswer(innerJoin, Row(1) :: Nil) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 5e323c02b253d..45afbd29d1907 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -185,6 +185,23 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } + + test("SPARK-20430 Initialize Range parameters in a driver side") { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + checkAnswer(sql("SELECT * FROM range(3)"), Row(0) :: Row(1) :: Row(2) :: Nil) + } + } + + test("SPARK-21041 SparkSession.range()'s behavior is inconsistent with SparkContext.range()") { + val start = java.lang.Long.MAX_VALUE - 3 + val end = java.lang.Long.MIN_VALUE + 2 + Seq("false", "true").foreach { value => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value) { + assert(spark.range(start, end, 1).collect.length == 0) + assert(spark.range(start, start, 1).collect.length == 0) + } + } + } } object DataFrameRangeSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 97890a035a62f..dd118f88e3bb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -68,25 +68,38 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("randomSplit on reordered partitions") { - // This test ensures that randomSplit does not create overlapping splits even when the - // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of - // rows in each partition. - val data = - sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id") - val splits = data.randomSplit(Array[Double](2, 3), seed = 1) - assert(splits.length == 2, "wrong number of splits") + def testNonOverlappingSplits(data: DataFrame): Unit = { + val splits = data.randomSplit(Array[Double](2, 3), seed = 1) + assert(splits.length == 2, "wrong number of splits") + + // Verify that the splits span the entire dataset + assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) - // Verify that the splits span the entire dataset - assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) + // Verify that the splits don't overlap + assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty) - // Verify that the splits don't overlap - assert(splits(0).intersect(splits(1)).collect().isEmpty) + // Verify that the results are deterministic across multiple runs + val firstRun = splits.toSeq.map(_.collect().toSeq) + val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) + assert(firstRun == secondRun) + } - // Verify that the results are deterministic across multiple runs - val firstRun = splits.toSeq.map(_.collect().toSeq) - val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) - assert(firstRun == secondRun) + // This test ensures that randomSplit does not create overlapping splits even when the + // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of + // rows in each partition. + val dataWithInts = sparkContext.parallelize(1 to 600, 2) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int") + val dataWithMaps = sparkContext.parallelize(1 to 600, 2) + .map(i => (i, Map(i -> i.toString))) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map") + val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2) + .map(i => (i, Array(Map(i -> i.toString)))) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps") + + testNonOverlappingSplits(dataWithInts) + testNonOverlappingSplits(dataWithMaps) + testNonOverlappingSplits(dataWithArrayOfMaps) } test("pearson correlation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 52bd4e19f8952..3fa538c51bde6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1722,4 +1722,41 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { "Cannot have map type columns in DataFrame which calls set operations")) } } + + test("SPARK-20359: catalyst outer join optimization should not throw npe") { + val df1 = Seq("a", "b", "c").toDF("x") + .withColumn("y", udf{ (x: String) => x.substring(0, 1) + "!" }.apply($"x")) + val df2 = Seq("a", "b").toDF("x1") + df1 + .join(df2, df1("x") === df2("x1"), "left_outer") + .filter($"x1".isNotNull || !$"y".isin("a!")) + .count + } + + testQuietly("SPARK-19372: Filter can be executed w/o generated code due to JVM code size limit") { + val N = 400 + val rows = Seq(Row.fromSeq(Seq.fill(N)("string"))) + val schema = StructType(Seq.tabulate(N)(i => StructField(s"_c$i", StringType))) + val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema) + + val filter = (0 until N) + .foldLeft(lit(false))((e, index) => e.or(df.col(df.columns(index)) =!= "string")) + df.filter(filter).count + } + + test("SPARK-20897: cached self-join should not fail") { + // force to plan sort merge join + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + val df = Seq(1 -> "a").toDF("i", "j") + val df1 = df.as("t1") + val df2 = df.as("t2") + assert(df1.join(df2, $"t1.i" === $"t2.i").cache().count() == 1) + } + } + + test("order-by ordinal.") { + checkAnswer( + testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)), + Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 1255c49104718..204858fa29787 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{DataType, LongType, StructType} +import org.apache.spark.sql.types._ /** * Window function testing for DataFrame API. @@ -423,4 +424,48 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { df.select(selectList: _*).where($"value" < 2), Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0))) } + + test("SPARK-21258: complex object in combination with spilling") { + // Make sure we trigger the spilling path. + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "17") { + val sampleSchema = new StructType(). + add("f0", StringType). + add("f1", LongType). + add("f2", ArrayType(new StructType(). + add("f20", StringType))). + add("f3", ArrayType(new StructType(). + add("f30", StringType))) + + val w0 = Window.partitionBy("f0").orderBy("f1") + val w1 = w0.rowsBetween(Long.MinValue, Long.MaxValue) + + val c0 = first(struct($"f2", $"f3")).over(w0) as "c0" + val c1 = last(struct($"f2", $"f3")).over(w1) as "c1" + + val input = + """{"f1":1497820153720,"f2":[{"f20":"x","f21":0}],"f3":[{"f30":"x","f31":0}]} + |{"f1":1497802179638} + |{"f1":1497802189347} + |{"f1":1497802189593} + |{"f1":1497802189597} + |{"f1":1497802189599} + |{"f1":1497802192103} + |{"f1":1497802193414} + |{"f1":1497802193577} + |{"f1":1497802193709} + |{"f1":1497802202883} + |{"f1":1497802203006} + |{"f1":1497802203743} + |{"f1":1497802203834} + |{"f1":1497802203887} + |{"f1":1497802203893} + |{"f1":1497802203976} + |{"f1":1497820168098} + |""".stripMargin.split("\n").toSeq + + import testImplicits._ + + spark.read.schema(sampleSchema).json(input.toDS()).select(c0, c1).foreach { _ => () } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 541565344f758..212ee1b39adf1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -32,6 +32,9 @@ case class QueueClass(q: Queue[Int]) case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass) +case class InnerData(name: String, value: Int) +case class NestedData(id: Int, param: Map[String, InnerData]) + package object packageobject { case class PackageClass(value: Int) } @@ -258,9 +261,19 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2)))) } + test("nested sequences") { + checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1))) + checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1))) + } + test("package objects") { import packageobject._ checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) } + test("SPARK-19104: Lambda variables in ExternalMapToCatalyst should be global") { + val data = Seq.tabulate(10)(i => NestedData(1, Map("key" -> InnerData("name", i + 100)))) + val ds = spark.createDataset(data) + checkDataset(ds, data: _*) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala index 92c5656f65bb4..68f7de047b392 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql import com.esotericsoftware.kryo.{Kryo, Serializer} import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.spark.SparkConf import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.test.TestSparkSession /** * Test suite to test Kryo custom registrators. @@ -30,12 +30,10 @@ import org.apache.spark.sql.test.TestSparkSession class DatasetSerializerRegistratorSuite extends QueryTest with SharedSQLContext { import testImplicits._ - /** - * Initialize the [[TestSparkSession]] with a [[KryoRegistrator]]. - */ - protected override def beforeAll(): Unit = { - sparkConf.set("spark.kryo.registrator", TestRegistrator().getClass.getCanonicalName) - super.beforeAll() + + override protected def sparkConf: SparkConf = { + // Make sure we use the KryoRegistrator + super.sparkConf.set("spark.kryo.registrator", TestRegistrator().getClass.getCanonicalName) } test("Kryo registrator") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 5b5cd28ad0c99..060fbe7cf81ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -21,11 +21,13 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} +import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -320,6 +322,21 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ((("b", 2), ("b", 2)), ("b", 2))) } + test("joinWith join types") { + val ds1 = Seq(1, 2, 3).toDS().as("a") + val ds2 = Seq(1, 2).toDS().as("b") + + val e1 = intercept[AnalysisException] { + ds1.joinWith(ds2, $"a.value" === $"b.value", "left_semi") + }.getMessage + assert(e1.contains("Invalid join type in joinWith: " + LeftSemi.sql)) + + val e2 = intercept[AnalysisException] { + ds1.joinWith(ds2, $"a.value" === $"b.value", "left_anti") + }.getMessage + assert(e2.contains("Invalid join type in joinWith: " + LeftAnti.sql)) + } + test("groupBy function, keys") { val ds = Seq(("a", 1), ("b", 1)).toDS() val grouped = ds.groupByKey(v => (1, v._2)) @@ -1168,6 +1185,31 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq(WithMapInOption(Some(Map(1 -> 1)))).toDS() checkDataset(ds, WithMapInOption(Some(Map(1 -> 1)))) } + + test("SPARK-20399: do not unescaped regex pattern when ESCAPED_STRING_LITERALS is enabled") { + withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> "true") { + val data = Seq("\u0020\u0021\u0023", "abc") + val df = data.toDF() + val rlike1 = df.filter("value rlike '^\\x20[\\x20-\\x23]+$'") + val rlike2 = df.filter($"value".rlike("^\\x20[\\x20-\\x23]+$")) + val rlike3 = df.filter("value rlike '^\\\\x20[\\\\x20-\\\\x23]+$'") + checkAnswer(rlike1, rlike2) + assert(rlike3.count() == 0) + } + } + + test("SPARK-21538: Attribute resolution inconsistency in Dataset API") { + val df = spark.range(3).withColumnRenamed("id", "x") + val expected = Row(0) :: Row(1) :: Row (2) :: Nil + checkAnswer(df.sort("id"), expected) + checkAnswer(df.sort(col("id")), expected) + checkAnswer(df.sort($"id"), expected) + checkAnswer(df.sort('id), expected) + checkAnswer(df.orderBy("id"), expected) + checkAnswer(df.orderBy(col("id")), expected) + checkAnswer(df.orderBy($"id"), expected) + checkAnswer(df.orderBy('id), expected) + } } case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index cef5bbf0e85a7..b9871afd59e4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -91,7 +91,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList") checkAnswer( df.select(explode_outer('intList)), - Row(1) :: Row(2) :: Row(3) :: Nil) + Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) } test("single posexplode") { @@ -105,7 +105,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList") checkAnswer( df.select(posexplode_outer('intList)), - Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil) + Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Row(null, null) :: Nil) } test("explode and other columns") { @@ -161,7 +161,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('intList).as('int)).select('int), - Row(1) :: Row(2) :: Row(3) :: Nil) + Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) checkAnswer( df.select(explode('intList).as('int)).select(sum('int)), @@ -182,7 +182,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('map)), - Row("a", "b") :: Row("c", "d") :: Nil) + Row("a", "b") :: Row(null, null) :: Row("c", "d") :: Nil) } test("explode on map with aliases") { @@ -198,7 +198,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"), - Row("a", "b") :: Nil) + Row("a", "b") :: Row(null, null) :: Nil) } test("self join explode") { @@ -279,7 +279,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { ) checkAnswer( df2.selectExpr("inline_outer(col1)"), - Row(3, "4") :: Row(5, "6") :: Nil + Row(null, null) :: Row(3, "4") :: Row(5, "6") :: Nil ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 1a66aa85f5a02..cdfd33dfb91a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.language.existentials @@ -25,6 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} class JoinSuite extends QueryTest with SharedSQLContext { @@ -198,6 +200,14 @@ class JoinSuite extends QueryTest with SharedSQLContext { Nil) } + test("SPARK-22141: Propagate empty relation before checking Cartesian products") { + Seq("inner", "left", "right", "left_outer", "right_outer", "full_outer").foreach { joinType => + val x = testData2.where($"a" === 2 && !($"a" === 2)).as("x") + val y = testData2.where($"a" === 1 && !($"a" === 1)).as("y") + checkAnswer(x.join(y, Seq.empty, joinType), Nil) + } + } + test("big inner join, 4 matches per row") { val bigData = testData.union(testData).union(testData).union(testData) val bigDataX = bigData.as("x") @@ -665,7 +675,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { test("test SortMergeJoin (with spill)") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", - "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "0") { + "spark.sql.sortMergeJoinExec.buffer.in.memory.threshold" -> "0", + "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "1") { assertSpilled(sparkContext, "inner join") { checkAnswer( @@ -738,4 +749,22 @@ class JoinSuite extends QueryTest with SharedSQLContext { } } } + + test("outer broadcast hash join should not throw NPE") { + withTempView("v1", "v2") { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") { + Seq(2 -> 2).toDF("x", "y").createTempView("v1") + + spark.createDataFrame( + Seq(Row(1, "a")).asJava, + new StructType().add("i", "int", nullable = false).add("j", "string", nullable = false) + ).createTempView("v2") + + checkAnswer( + sql("select x, y, i, j from v1 left join v2 on x = i and y < length(j)"), + Row(2, 2, null, null) + ) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 8465e8d036a6d..989f8c23a4069 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.functions.{from_json, struct, to_json} +import org.apache.spark.sql.functions.{from_json, lit, map, struct, to_json} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -188,15 +188,33 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row("""{"_1":"26/08/2015 18:00"}""") :: Nil) } - test("to_json unsupported type") { + test("to_json - key types of map don't matter") { + // interval type is invalid for converting to JSON. However, the keys of a map are treated + // as strings, so its type doesn't matter. val df = Seq(Tuple1(Tuple1("interval -3 month 7 hours"))).toDF("a") - .select(struct($"a._1".cast(CalendarIntervalType).as("a")).as("c")) + .select(struct(map($"a._1".cast(CalendarIntervalType), lit("a")).as("col1")).as("c")) + checkAnswer( + df.select(to_json($"c")), + Row("""{"col1":{"interval -3 months 7 hours":"a"}}""") :: Nil) + } + + test("to_json unsupported type") { + val baseDf = Seq(Tuple1(Tuple1("interval -3 month 7 hours"))).toDF("a") + val df = baseDf.select(struct($"a._1".cast(CalendarIntervalType).as("a")).as("c")) val e = intercept[AnalysisException]{ // Unsupported type throws an exception df.select(to_json($"c")).collect() } assert(e.getMessage.contains( "Unable to convert column a of type calendarinterval to JSON.")) + + // interval type is invalid for converting to JSON. We can't use it as value type of a map. + val df2 = baseDf + .select(struct(map(lit("a"), $"a._1".cast(CalendarIntervalType)).as("col1")).as("c")) + val e2 = intercept[AnalysisException] { + df2.select(to_json($"c")).collect() + } + assert(e2.getMessage.contains("Unable to convert column col1 of type calendarinterval to JSON")) } test("roundtrip in to_json and from_json - struct") { @@ -274,7 +292,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { val errMsg2 = intercept[AnalysisException] { df3.selectExpr("""from_json(value, 'time InvalidType')""") } - assert(errMsg2.getMessage.contains("DataType invalidtype() is not supported")) + assert(errMsg2.getMessage.contains("DataType invalidtype is not supported")) val errMsg3 = intercept[AnalysisException] { df3.selectExpr("from_json(value, 'time Timestamp', named_struct('a', 1))") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index 328c5395ec91e..c2d08a06569bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -231,6 +231,19 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) ) + + val bdPi: BigDecimal = BigDecimal(31415925L, 7) + checkAnswer( + sql(s"SELECT round($bdPi, 7), round($bdPi, 8), round($bdPi, 9), round($bdPi, 10), " + + s"round($bdPi, 100), round($bdPi, 6), round(null, 8)"), + Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141593"), null)) + ) + + checkAnswer( + sql(s"SELECT bround($bdPi, 7), bround($bdPi, 8), bround($bdPi, 9), bround($bdPi, 10), " + + s"bround($bdPi, 100), bround($bdPi, 6), bround(null, 8)"), + Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141592"), null)) + ) } test("round/bround with data frame from a local Seq of Product") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala index 52c200796ce41..623a1b6f854cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala @@ -22,20 +22,22 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.streaming.ProcessingTime +import org.apache.spark.sql.streaming.{ProcessingTime, Trigger} class ProcessingTimeSuite extends SparkFunSuite { test("create") { - assert(ProcessingTime(10.seconds).intervalMs === 10 * 1000) - assert(ProcessingTime.create(10, TimeUnit.SECONDS).intervalMs === 10 * 1000) - assert(ProcessingTime("1 minute").intervalMs === 60 * 1000) - assert(ProcessingTime("interval 1 minute").intervalMs === 60 * 1000) - - intercept[IllegalArgumentException] { ProcessingTime(null: String) } - intercept[IllegalArgumentException] { ProcessingTime("") } - intercept[IllegalArgumentException] { ProcessingTime("invalid") } - intercept[IllegalArgumentException] { ProcessingTime("1 month") } - intercept[IllegalArgumentException] { ProcessingTime("1 year") } + def getIntervalMs(trigger: Trigger): Long = trigger.asInstanceOf[ProcessingTime].intervalMs + + assert(getIntervalMs(Trigger.ProcessingTime(10.seconds)) === 10 * 1000) + assert(getIntervalMs(Trigger.ProcessingTime(10, TimeUnit.SECONDS)) === 10 * 1000) + assert(getIntervalMs(Trigger.ProcessingTime("1 minute")) === 60 * 1000) + assert(getIntervalMs(Trigger.ProcessingTime("interval 1 minute")) === 60 * 1000) + + intercept[IllegalArgumentException] { Trigger.ProcessingTime(null: String) } + intercept[IllegalArgumentException] { Trigger.ProcessingTime("") } + intercept[IllegalArgumentException] { Trigger.ProcessingTime("invalid") } + intercept[IllegalArgumentException] { Trigger.ProcessingTime("1 month") } + intercept[IllegalArgumentException] { Trigger.ProcessingTime("1 year") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0dd9296a3f0ff..c6a6efda59879 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.io.File import java.math.MathContext +import java.net.{MalformedURLException, URL} import java.sql.Timestamp import java.util.concurrent.atomic.AtomicBoolean @@ -2606,4 +2607,43 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { case ae: AnalysisException => assert(ae.plan == null && ae.getMessage == ae.getSimpleMessage) } } + + test("SPARK-12868: Allow adding jars from hdfs ") { + val jarFromHdfs = "hdfs://doesnotmatter/test.jar" + val jarFromInvalidFs = "fffs://doesnotmatter/test.jar" + + // if 'hdfs' is not supported, MalformedURLException will be thrown + new URL(jarFromHdfs) + + intercept[MalformedURLException] { + new URL(jarFromInvalidFs) + } + } + + test("RuntimeReplaceable functions should not take extra parameters") { + val e = intercept[AnalysisException](sql("SELECT nvl(1, 2, 3)")) + assert(e.message.contains("Invalid number of arguments")) + } + + test("SPARK-21228: InSet incorrect handling of structs") { + withTempView("A") { + // reduce this from the default of 10 so the repro query text is not too long + withSQLConf((SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "3")) { + // a relation that has 1 column of struct type with values (1,1), ..., (9, 9) + spark.range(1, 10).selectExpr("named_struct('a', id, 'b', id) as a") + .createOrReplaceTempView("A") + val df = sql( + """ + |SELECT * from + | (SELECT MIN(a) as minA FROM A) AA -- this Aggregate will return UnsafeRows + | -- the IN will become InSet with a Set of GenericInternalRows + | -- a GenericInternalRow is never equal to an UnsafeRow so the query would + | -- returns 0 results, which is incorrect + | WHERE minA IN (NAMED_STRUCT('a', 1L, 'b', 1L), NAMED_STRUCT('a', 2L, 'b', 2L), + | NAMED_STRUCT('a', 3L, 'b', 3L)) + """.stripMargin) + checkAnswer(df, Row(Row(1, 1))) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala index 386d13d07a95f..1c6afa5e26e14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -17,49 +17,48 @@ package org.apache.spark.sql +import org.scalatest.BeforeAndAfterEach + import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} /** * Test cases for the builder pattern of [[SparkSession]]. */ -class SparkSessionBuilderSuite extends SparkFunSuite { +class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach { - private var initialSession: SparkSession = _ + override def afterEach(): Unit = { + // This suite should not interfere with the other test suites. + SparkSession.getActiveSession.foreach(_.stop()) + SparkSession.clearActiveSession() + SparkSession.getDefaultSession.foreach(_.stop()) + SparkSession.clearDefaultSession() + } - private lazy val sparkContext: SparkContext = { - initialSession = SparkSession.builder() + test("create with config options and propagate them to SparkContext and SparkSession") { + val session = SparkSession.builder() .master("local") .config("spark.ui.enabled", value = false) .config("some-config", "v2") .getOrCreate() - initialSession.sparkContext - } - - test("create with config options and propagate them to SparkContext and SparkSession") { - // Creating a new session with config - this works by just calling the lazy val - sparkContext - assert(initialSession.sparkContext.conf.get("some-config") == "v2") - assert(initialSession.conf.get("some-config") == "v2") - SparkSession.clearDefaultSession() + assert(session.sparkContext.conf.get("some-config") == "v2") + assert(session.conf.get("some-config") == "v2") } test("use global default session") { - val session = SparkSession.builder().getOrCreate() + val session = SparkSession.builder().master("local").getOrCreate() assert(SparkSession.builder().getOrCreate() == session) - SparkSession.clearDefaultSession() } test("config options are propagated to existing SparkSession") { - val session1 = SparkSession.builder().config("spark-config1", "a").getOrCreate() + val session1 = SparkSession.builder().master("local").config("spark-config1", "a").getOrCreate() assert(session1.conf.get("spark-config1") == "a") val session2 = SparkSession.builder().config("spark-config1", "b").getOrCreate() assert(session1 == session2) assert(session1.conf.get("spark-config1") == "b") - SparkSession.clearDefaultSession() } test("use session from active thread session and propagate config options") { - val defaultSession = SparkSession.builder().getOrCreate() + val defaultSession = SparkSession.builder().master("local").getOrCreate() val activeSession = defaultSession.newSession() SparkSession.setActiveSession(activeSession) val session = SparkSession.builder().config("spark-config2", "a").getOrCreate() @@ -70,16 +69,14 @@ class SparkSessionBuilderSuite extends SparkFunSuite { SparkSession.clearActiveSession() assert(SparkSession.builder().getOrCreate() == defaultSession) - SparkSession.clearDefaultSession() } test("create a new session if the default session has been stopped") { - val defaultSession = SparkSession.builder().getOrCreate() + val defaultSession = SparkSession.builder().master("local").getOrCreate() SparkSession.setDefaultSession(defaultSession) defaultSession.stop() val newSession = SparkSession.builder().master("local").getOrCreate() assert(newSession != defaultSession) - newSession.stop() } test("create a new session if the active thread session has been stopped") { @@ -88,11 +85,9 @@ class SparkSessionBuilderSuite extends SparkFunSuite { activeSession.stop() val newSession = SparkSession.builder().master("local").getOrCreate() assert(newSession != activeSession) - newSession.stop() } test("create SparkContext first then SparkSession") { - sparkContext.stop() val conf = new SparkConf().setAppName("test").setMaster("local").set("key1", "value1") val sparkContext2 = new SparkContext(conf) val session = SparkSession.builder().config("key2", "value2").getOrCreate() @@ -101,14 +96,12 @@ class SparkSessionBuilderSuite extends SparkFunSuite { assert(session.sparkContext.conf.get("key1") == "value1") assert(session.sparkContext.conf.get("key2") == "value2") assert(session.sparkContext.conf.get("spark.app.name") == "test") - session.stop() } test("SPARK-15887: hive-site.xml should be loaded") { val session = SparkSession.builder().master("local").getOrCreate() assert(session.sessionState.newHadoopConf().get("hive.in.test") == "true") assert(session.sparkContext.hadoopConfiguration.get("hive.in.test") == "true") - session.stop() } test("SPARK-15991: Set global Hadoop conf") { @@ -120,7 +113,6 @@ class SparkSessionBuilderSuite extends SparkFunSuite { assert(session.sessionState.newHadoopConf().get(mySpecialKey) == mySpecialValue) } finally { session.sparkContext.hadoopConfiguration.unset(mySpecialKey) - session.stop() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala new file mode 100644 index 0000000000000..43db79663322a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.apache.spark.sql.types.{DataType, StructType} + +/** + * Test cases for the [[SparkSessionExtensions]]. + */ +class SparkSessionExtensionSuite extends SparkFunSuite { + type ExtensionsBuilder = SparkSessionExtensions => Unit + private def create(builder: ExtensionsBuilder): ExtensionsBuilder = builder + + private def stop(spark: SparkSession): Unit = { + spark.stop() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + + private def withSession(builder: ExtensionsBuilder)(f: SparkSession => Unit): Unit = { + val spark = SparkSession.builder().master("local[1]").withExtensions(builder).getOrCreate() + try f(spark) finally { + stop(spark) + } + } + + test("inject analyzer rule") { + withSession(_.injectResolutionRule(MyRule)) { session => + assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session))) + } + } + + test("inject check analysis rule") { + withSession(_.injectCheckRule(MyCheckRule)) { session => + assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session))) + } + } + + test("inject optimizer rule") { + withSession(_.injectOptimizerRule(MyRule)) { session => + assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session))) + } + } + + test("inject spark planner strategy") { + withSession(_.injectPlannerStrategy(MySparkStrategy)) { session => + assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) + } + } + + test("inject parser") { + val extension = create { extensions => + extensions.injectParser((_, _) => CatalystSqlParser) + } + withSession(extension) { session => + assert(session.sessionState.sqlParser == CatalystSqlParser) + } + } + + test("inject stacked parsers") { + val extension = create { extensions => + extensions.injectParser((_, _) => CatalystSqlParser) + extensions.injectParser(MyParser) + extensions.injectParser(MyParser) + } + withSession(extension) { session => + val parser = MyParser(session, MyParser(session, CatalystSqlParser)) + assert(session.sessionState.sqlParser == parser) + } + } + + test("use custom class for extensions") { + val session = SparkSession.builder() + .master("local[1]") + .config("spark.sql.extensions", classOf[MyExtensions].getCanonicalName) + .getOrCreate() + try { + assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) + assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session))) + } finally { + stop(session) + } + } +} + +case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan +} + +case class MyCheckRule(spark: SparkSession) extends (LogicalPlan => Unit) { + override def apply(plan: LogicalPlan): Unit = { } +} + +case class MySparkStrategy(spark: SparkSession) extends SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty +} + +case class MyParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface { + override def parsePlan(sqlText: String): LogicalPlan = + delegate.parsePlan(sqlText) + + override def parseExpression(sqlText: String): Expression = + delegate.parseExpression(sqlText) + + override def parseTableIdentifier(sqlText: String): TableIdentifier = + delegate.parseTableIdentifier(sqlText) + + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = + delegate.parseFunctionIdentifier(sqlText) + + override def parseTableSchema(sqlText: String): StructType = + delegate.parseTableSchema(sqlText) + + override def parseDataType(sqlText: String): DataType = + delegate.parseDataType(sqlText) +} + +class MyExtensions extends (SparkSessionExtensions => Unit) { + def apply(e: SparkSessionExtensions): Unit = { + e.injectPlannerStrategy(MySparkStrategy) + e.injectResolutionRule(MyRule) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index ddc393c8da053..86d19af9dd548 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import scala.util.Random import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, HiveTableRelation} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -40,17 +40,6 @@ import org.apache.spark.sql.types._ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with SharedSQLContext { import testImplicits._ - private def checkTableStats(tableName: String, expectedRowCount: Option[Int]) - : Option[CatalogStatistics] = { - val df = spark.table(tableName) - val stats = df.queryExecution.analyzed.collect { case rel: LogicalRelation => - assert(rel.catalogTable.get.stats.flatMap(_.rowCount) === expectedRowCount) - rel.catalogTable.get.stats - } - assert(stats.size == 1) - stats.head - } - test("estimates the size of a limit 0 on outer join") { withTempView("test") { Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") @@ -88,6 +77,19 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } + test("analyze empty table") { + val table = "emptyTable" + withTable(table) { + sql(s"CREATE TABLE $table (key STRING, value STRING) USING PARQUET") + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS noscan") + val fetchedStats1 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = None) + assert(fetchedStats1.get.sizeInBytes == 0) + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + val fetchedStats2 = checkTableStats(table, hasSizeInBytes = true, expectedRowCounts = Some(0)) + assert(fetchedStats2.get.sizeInBytes == 0) + } + } + test("test table-level statistics for data source table") { val tableName = "tbl" withTable(tableName) { @@ -96,11 +98,11 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared // noscan won't count the number of rows sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") - checkTableStats(tableName, expectedRowCount = None) + checkTableStats(tableName, hasSizeInBytes = true, expectedRowCounts = None) // without noscan, we count the number of rows sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") - checkTableStats(tableName, expectedRowCount = Some(2)) + checkTableStats(tableName, hasSizeInBytes = true, expectedRowCounts = Some(2)) } } @@ -164,7 +166,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared numbers.foreach { case (input, (expectedSize, expectedRows)) => val stats = Statistics(sizeInBytes = input, rowCount = Some(input)) val expectedString = s"sizeInBytes=$expectedSize, rowCount=$expectedRows," + - s" isBroadcastable=${stats.isBroadcastable}" + s" hints=none" assert(stats.simpleString == expectedString) } } @@ -219,6 +221,23 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils private val randomName = new Random(31) + def checkTableStats( + tableName: String, + hasSizeInBytes: Boolean, + expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { + val stats = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).stats + + if (hasSizeInBytes || expectedRowCounts.nonEmpty) { + assert(stats.isDefined) + assert(stats.get.sizeInBytes >= 0) + assert(stats.get.rowCount === expectedRowCounts) + } else { + assert(stats.isEmpty) + } + + stats + } + /** * Compute column stats for the given DataFrame and compare it with colStats. */ @@ -285,7 +304,7 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils // Analyze only one column. sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS FOR COLUMNS c1") val (relation, catalogTable) = spark.table(tableName).queryExecution.analyzed.collect { - case catalogRel: CatalogRelation => (catalogRel, catalogRel.tableMeta) + case catalogRel: HiveTableRelation => (catalogRel, catalogRel.tableMeta) case logicalRel: LogicalRelation => (logicalRel, logicalRel.catalogTable.get) }.head val emptyColStat = ColumnStat(0, None, None, 0, 4, 4) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 0f0199cbe2777..2a3bdfbfa0108 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -72,7 +72,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { } } - test("rdd deserialization does not crash [SPARK-15791]") { + test("SPARK-15791: rdd deserialization does not crash") { sql("select (select 1 as b) as b").rdd.count() } @@ -854,4 +854,12 @@ class SubquerySuite extends QueryTest with SharedSQLContext { sql("select * from l, r where l.a = r.c + 1 AND (exists (select * from r) OR l.a = r.c)"), Row(3, 3.0, 2, 3.0) :: Row(3, 3.0, 2, 3.0) :: Nil) } + + test("SPARK-20688: correctly check analysis for scalar sub-queries") { + withTempView("t") { + Seq(1 -> "a").toDF("i", "j").createTempView("t") + val e = intercept[AnalysisException](sql("SELECT (SELECT count(*) FROM t WHERE a = 1)")) + assert(e.message.contains("cannot resolve '`a`' given input columns: [i, j]")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index ae6b2bc3753fb..6f8723af91cea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -93,6 +93,13 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } + test("UDF defined using UserDefinedFunction") { + import functions.udf + val foo = udf((x: Int) => x + 1) + spark.udf.register("foo", foo) + assert(sql("select foo(5)").head().getInt(0) == 6) + } + test("ZeroArgument UDF") { spark.udf.register("random0", () => { Math.random()}) assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index a32763db054f3..a5f904c621e6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -101,9 +101,22 @@ class UnsafeRowSuite extends SparkFunSuite { MemoryAllocator.UNSAFE.free(offheapRowPage) } } + val (bytesFromArrayBackedRowWithOffset, field0StringFromArrayBackedRowWithOffset) = { + val baos = new ByteArrayOutputStream() + val numBytes = arrayBackedUnsafeRow.getSizeInBytes + val bytesWithOffset = new Array[Byte](numBytes + 100) + System.arraycopy(arrayBackedUnsafeRow.getBaseObject.asInstanceOf[Array[Byte]], 0, + bytesWithOffset, 100, numBytes) + val arrayBackedRow = new UnsafeRow(arrayBackedUnsafeRow.numFields()) + arrayBackedRow.pointTo(bytesWithOffset, Platform.BYTE_ARRAY_OFFSET + 100, numBytes) + arrayBackedRow.writeToStream(baos, null) + (baos.toByteArray, arrayBackedRow.getString(0)) + } assert(bytesFromArrayBackedRow === bytesFromOffheapRow) assert(field0StringFromArrayBackedRow === field0StringFromOffheapRow) + assert(bytesFromArrayBackedRow === bytesFromArrayBackedRowWithOffset) + assert(field0StringFromArrayBackedRow === field0StringFromArrayBackedRowWithOffset) } test("calling getDouble() and getFloat() on null columns") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala index 05a2b2c862c73..423e1288e8dcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala @@ -18,22 +18,17 @@ package org.apache.spark.sql.execution import org.apache.hadoop.fs.Path +import org.apache.spark.SparkConf import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.util.Utils /** * Suite that tests the redaction of DataSourceScanExec */ class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext { - import Utils._ - - override def beforeAll(): Unit = { - sparkConf.set("spark.redaction.string.regex", - "file:/[\\w_]+") - super.beforeAll() - } + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.redaction.string.regex", "file:/[\\w_]+") test("treeString is redacted") { withTempDir { dir => @@ -43,7 +38,7 @@ class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext { val rootPath = df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get .asInstanceOf[FileSourceScanExec].relation.location.rootPaths.head - assert(rootPath.toString.contains(basePath.toString)) + assert(rootPath.toString.contains(dir.toURI.getPath.stripSuffix("/"))) assert(!df.queryExecution.sparkPlan.treeString(verbose = true).contains(rootPath.getName)) assert(!df.queryExecution.executedPlan.treeString(verbose = true).contains(rootPath.getName)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 06bce9a2400e7..0de1832a0fed9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -252,7 +252,8 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { def withSparkSession( f: SparkSession => Unit, targetNumPostShufflePartitions: Int, - minNumPostShufflePartitions: Option[Int]): Unit = { + minNumPostShufflePartitions: Option[Int], + adaptiveExecutionDisabledForJoin: Boolean = false): Unit = { val sparkConf = new SparkConf(false) .setMaster("local[*]") @@ -265,6 +266,9 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { .set( SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, targetNumPostShufflePartitions.toString) + if (adaptiveExecutionDisabledForJoin) { + sparkConf.set(SQLConf.ADAPTIVE_EXECUTION_DISABLED_FOR_JOINING.key, "true") + } minNumPostShufflePartitions match { case Some(numPartitions) => sparkConf.set(SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key, numPartitions.toString) @@ -477,7 +481,86 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } } - withSparkSession(test, 6144, minNumPostShufflePartitions) + withSparkSession( + test, + 6144, + minNumPostShufflePartitions, + adaptiveExecutionDisabledForJoin = false + ) } + /* + test(s"adaptive execution disabled for joining: complex query 3$testNameNote") { + val test = { spark: SparkSession => + val df1 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key1", "id as value1") + .groupBy("key1") + .count() + .toDF("key1", "cnt1") + val df2 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key2", "id as value2") + + val join = + df1 + .join(df2, col("key1") === col("key2")) + .select(col("key1"), col("cnt1"), col("value2")) + + // Check the answer first. + val expectedAnswer = + spark + .range(0, 1000) + .selectExpr("id % 500 as key", "2 as cnt", "id as value") + checkAnswer( + join, + expectedAnswer.collect()) + + // Verify that adaptive execution is disabled for SortMergeJoin but not for others + val executedPlan = join.queryExecution.executedPlan + val verified = new mutable.HashSet[SparkPlan] + val waitingToVerify = new mutable.Stack[(SparkPlan, Boolean)] + def verify(plan: SparkPlan, childOfJoin: Boolean): Unit = { + if (!verified(plan)) { + verified += plan + plan match { + case SortMergeJoinExec(_, _, _, _, left, right) => + waitingToVerify.push((left, true)) + waitingToVerify.push((right, true)) + case ShuffleExchange(newPartitioning, child, coordinator) => + if (childOfJoin) { + assert(coordinator.isEmpty) + waitingToVerify.push((child, false)) + } else { + minNumPostShufflePartitions match { + case Some(_) => + assert(coordinator.isDefined) + assert(newPartitioning.numPartitions === 3) + case None => + assert(coordinator.isDefined) + } + waitingToVerify.push((child, false)) + } + case _ => + plan.children.foreach { child => + waitingToVerify.push((child, childOfJoin)) + } + } + } + } + waitingToVerify.push((executedPlan, false)) + while(waitingToVerify.nonEmpty) { + val (plan, childOfJoin) = waitingToVerify.pop() + verify(plan, childOfJoin) + } + } + + withSparkSession( + test, + 6144, + minNumPostShufflePartitions, + adaptiveExecutionDisabledForJoin = true) + } */ } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala index 00c5f2550cbb1..a5adc3639ad64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala @@ -67,7 +67,10 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark { benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => var sum = 0L for (_ <- 0L until iterations) { - val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold) + val array = new ExternalAppendOnlyUnsafeRowArray( + ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer, + numSpillThreshold) + rows.foreach(x => array.add(x)) val iterator = array.generateIterator() @@ -143,7 +146,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark { benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => var sum = 0L for (_ <- 0L until iterations) { - val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold) + val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold, numSpillThreshold) rows.foreach(x => array.add(x)) val iterator = array.generateIterator() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala index 53c41639942b4..ecc7264d79442 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -31,7 +31,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar override def afterAll(): Unit = TaskContext.unset() - private def withExternalArray(spillThreshold: Int) + private def withExternalArray(inMemoryThreshold: Int, spillThreshold: Int) (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = { sc = new SparkContext("local", "test", new SparkConf(false)) @@ -45,6 +45,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar taskContext, 1024, SparkEnv.get.memoryManager.pageSizeBytes, + inMemoryThreshold, spillThreshold) try f(array) finally { array.clear() @@ -109,9 +110,9 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar assert(getNumBytesSpilled > 0) } - test("insert rows less than the spillThreshold") { - val spillThreshold = 100 - withExternalArray(spillThreshold) { array => + test("insert rows less than the inMemoryThreshold") { + val (inMemoryThreshold, spillThreshold) = (100, 50) + withExternalArray(inMemoryThreshold, spillThreshold) { array => assert(array.isEmpty) val expectedValues = populateRows(array, 1) @@ -122,8 +123,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar // Add more rows (but not too many to trigger switch to [[UnsafeExternalSorter]]) // Verify that NO spill has happened - populateRows(array, spillThreshold - 1, expectedValues) - assert(array.length == spillThreshold) + populateRows(array, inMemoryThreshold - 1, expectedValues) + assert(array.length == inMemoryThreshold) assertNoSpill() val iterator2 = validateData(array, expectedValues) @@ -133,20 +134,42 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } } - test("insert rows more than the spillThreshold to force spill") { - val spillThreshold = 100 - withExternalArray(spillThreshold) { array => - val numValuesInserted = 20 * spillThreshold - + test("insert rows more than the inMemoryThreshold but less than spillThreshold") { + val (inMemoryThreshold, spillThreshold) = (10, 50) + withExternalArray(inMemoryThreshold, spillThreshold) { array => assert(array.isEmpty) - val expectedValues = populateRows(array, 1) - assert(array.length == 1) + val expectedValues = populateRows(array, inMemoryThreshold - 1) + assert(array.length == (inMemoryThreshold - 1)) + val iterator1 = validateData(array, expectedValues) + assertNoSpill() + + // Add more rows to trigger switch to [[UnsafeExternalSorter]] but not too many to cause a + // spill to happen. Verify that NO spill has happened + populateRows(array, spillThreshold - expectedValues.length - 1, expectedValues) + assert(array.length == spillThreshold - 1) + assertNoSpill() + + val iterator2 = validateData(array, expectedValues) + assert(!iterator2.hasNext) + assert(!iterator1.hasNext) + intercept[ConcurrentModificationException](iterator1.next()) + } + } + + test("insert rows enough to force spill") { + val (inMemoryThreshold, spillThreshold) = (20, 10) + withExternalArray(inMemoryThreshold, spillThreshold) { array => + assert(array.isEmpty) + val expectedValues = populateRows(array, inMemoryThreshold - 1) + assert(array.length == (inMemoryThreshold - 1)) val iterator1 = validateData(array, expectedValues) + assertNoSpill() - // Populate more rows to trigger spill. Verify that spill has happened - populateRows(array, numValuesInserted - 1, expectedValues) - assert(array.length == numValuesInserted) + // Add more rows to trigger switch to [[UnsafeExternalSorter]] and cause a spill to happen. + // Verify that spill has happened + populateRows(array, 2, expectedValues) + assert(array.length == inMemoryThreshold + 1) assertSpill() val iterator2 = validateData(array, expectedValues) @@ -158,7 +181,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("iterator on an empty array should be empty") { - withExternalArray(spillThreshold = 10) { array => + withExternalArray(inMemoryThreshold = 4, spillThreshold = 10) { array => val iterator = array.generateIterator() assert(array.isEmpty) assert(array.length == 0) @@ -167,7 +190,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with negative start index") { - withExternalArray(spillThreshold = 2) { array => + withExternalArray(inMemoryThreshold = 100, spillThreshold = 56) { array => val exception = intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10)) @@ -178,8 +201,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with start index exceeding array's size (without spill)") { - val spillThreshold = 2 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => populateRows(array, spillThreshold / 2) val exception = @@ -191,8 +214,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with start index exceeding array's size (with spill)") { - val spillThreshold = 2 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => populateRows(array, spillThreshold * 2) val exception = @@ -205,10 +228,10 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with custom start index (without spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => - val expectedValues = populateRows(array, spillThreshold) - val startIndex = spillThreshold / 2 + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => + val expectedValues = populateRows(array, inMemoryThreshold) + val startIndex = inMemoryThreshold / 2 val iterator = array.generateIterator(startIndex = startIndex) for (i <- startIndex until expectedValues.length) { checkIfValueExists(iterator, expectedValues(i)) @@ -217,8 +240,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("generate iterator with custom start index (with spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => val expectedValues = populateRows(array, spillThreshold * 10) val startIndex = spillThreshold * 2 val iterator = array.generateIterator(startIndex = startIndex) @@ -229,7 +252,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("test iterator invalidation (without spill)") { - withExternalArray(spillThreshold = 10) { array => + withExternalArray(inMemoryThreshold = 10, spillThreshold = 100) { array => // insert 2 rows, iterate until the first row populateRows(array, 2) @@ -254,9 +277,9 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("test iterator invalidation (with spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => - // Populate enough rows so that spill has happens + val (inMemoryThreshold, spillThreshold) = (2, 10) + withExternalArray(inMemoryThreshold, spillThreshold) { array => + // Populate enough rows so that spill happens populateRows(array, spillThreshold * 2) assertSpill() @@ -281,7 +304,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("clear on an empty the array") { - withExternalArray(spillThreshold = 2) { array => + withExternalArray(inMemoryThreshold = 2, spillThreshold = 3) { array => val iterator = array.generateIterator() assert(!iterator.hasNext) @@ -299,10 +322,10 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar } test("clear array (without spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (10, 100) + withExternalArray(inMemoryThreshold, spillThreshold) { array => // Populate rows ... but not enough to trigger spill - populateRows(array, spillThreshold / 2) + populateRows(array, inMemoryThreshold / 2) assertNoSpill() // Clear the array @@ -311,21 +334,21 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar // Re-populate few rows so that there is no spill // Verify the data. Verify that there was no spill - val expectedValues = populateRows(array, spillThreshold / 3) + val expectedValues = populateRows(array, inMemoryThreshold / 2) validateData(array, expectedValues) assertNoSpill() // Populate more rows .. enough to not trigger a spill. // Verify the data. Verify that there was no spill - populateRows(array, spillThreshold / 3, expectedValues) + populateRows(array, inMemoryThreshold / 2, expectedValues) validateData(array, expectedValues) assertNoSpill() } } test("clear array (with spill)") { - val spillThreshold = 10 - withExternalArray(spillThreshold) { array => + val (inMemoryThreshold, spillThreshold) = (10, 20) + withExternalArray(inMemoryThreshold, spillThreshold) { array => // Populate enough rows to trigger spill populateRows(array, spillThreshold * 2) val bytesSpilled = getNumBytesSpilled diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala index 5c63c6a414f93..a3d75b221ec3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala @@ -35,39 +35,47 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { private var globalTempDB: String = _ test("basic semantic") { - sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'") + try { + sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1, 'a'") + + // If there is no database in table name, we should try local temp view first, if not found, + // try table/view in current database, which is "default" in this case. So we expect + // NoSuchTableException here. + intercept[NoSuchTableException](spark.table("src")) - // If there is no database in table name, we should try local temp view first, if not found, - // try table/view in current database, which is "default" in this case. So we expect - // NoSuchTableException here. - intercept[NoSuchTableException](spark.table("src")) + // Use qualified name to refer to the global temp view explicitly. + checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) - // Use qualified name to refer to the global temp view explicitly. - checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) + // Table name without database will never refer to a global temp view. + intercept[NoSuchTableException](sql("DROP VIEW src")) - // Table name without database will never refer to a global temp view. - intercept[NoSuchTableException](sql("DROP VIEW src")) + sql(s"DROP VIEW $globalTempDB.src") + // The global temp view should be dropped successfully. + intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) - sql(s"DROP VIEW $globalTempDB.src") - // The global temp view should be dropped successfully. - intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) + // We can also use Dataset API to create global temp view + Seq(1 -> "a").toDF("i", "j").createGlobalTempView("src") + checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) - // We can also use Dataset API to create global temp view - Seq(1 -> "a").toDF("i", "j").createGlobalTempView("src") - checkAnswer(spark.table(s"$globalTempDB.src"), Row(1, "a")) + // Use qualified name to rename a global temp view. + sql(s"ALTER VIEW $globalTempDB.src RENAME TO src2") + intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) + checkAnswer(spark.table(s"$globalTempDB.src2"), Row(1, "a")) - // Use qualified name to rename a global temp view. - sql(s"ALTER VIEW $globalTempDB.src RENAME TO src2") - intercept[NoSuchTableException](spark.table(s"$globalTempDB.src")) - checkAnswer(spark.table(s"$globalTempDB.src2"), Row(1, "a")) + // Use qualified name to alter a global temp view. + sql(s"ALTER VIEW $globalTempDB.src2 AS SELECT 2, 'b'") + checkAnswer(spark.table(s"$globalTempDB.src2"), Row(2, "b")) - // Use qualified name to alter a global temp view. - sql(s"ALTER VIEW $globalTempDB.src2 AS SELECT 2, 'b'") - checkAnswer(spark.table(s"$globalTempDB.src2"), Row(2, "b")) + // We can also use Catalog API to drop global temp view + spark.catalog.dropGlobalTempView("src2") + intercept[NoSuchTableException](spark.table(s"$globalTempDB.src2")) - // We can also use Catalog API to drop global temp view - spark.catalog.dropGlobalTempView("src2") - intercept[NoSuchTableException](spark.table(s"$globalTempDB.src2")) + // We can also use Dataset API to replace global temp view + Seq(2 -> "b").toDF("i", "j").createOrReplaceGlobalTempView("src") + checkAnswer(spark.table(s"$globalTempDB.src"), Row(2, "b")) + } finally { + spark.catalog.dropGlobalTempView("src") + } } test("global temp view is shared among all sessions") { @@ -106,7 +114,7 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext { test("CREATE TABLE LIKE should work for global temp view") { try { sql("CREATE GLOBAL TEMP VIEW src AS SELECT 1 AS a, '2' AS b") - sql(s"CREATE TABLE cloned LIKE ${globalTempDB}.src") + sql(s"CREATE TABLE cloned LIKE $globalTempDB.src") val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier("cloned")) assert(tableMeta.schema == new StructType().add("a", "int", false).add("b", "string", false)) } finally { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala index 58c310596ca6d..6c66902127d03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala @@ -117,4 +117,12 @@ class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext { "select partcol1, max(partcol2) from srcpart where partcol1 = 0 group by rollup (partcol1)", "select partcol2 from (select partcol2 from srcpart where partcol1 = 0 union all " + "select partcol2 from srcpart where partcol1 = 1) t group by partcol2") + + test("SPARK-21884 Fix StackOverflowError on MetadataOnlyQuery") { + withTable("t_1000") { + sql("CREATE TABLE t_1000 (a INT, p INT) USING PARQUET PARTITIONED BY (p)") + (1 to 1000).foreach(p => sql(s"ALTER TABLE t_1000 ADD PARTITION (p=$p)")) + sql("SELECT COUNT(DISTINCT p) FROM t_1000").collect() + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 4d155d538d637..63e17c7f372b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation @@ -513,26 +513,30 @@ class PlannerSuite extends SharedSQLContext { } test("EnsureRequirements skips sort when either side of join keys is required after inner SMJ") { - val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB) - // Both left and right keys should be sorted after the SMJ. - Seq(orderingA, orderingB).foreach { ordering => - assertSortRequirementsAreSatisfied( - childPlan = innerSmj, - requiredOrdering = Seq(ordering), - shouldHaveSort = false) + Seq(Inner, Cross).foreach { joinType => + val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, None, planA, planB) + // Both left and right keys should be sorted after the SMJ. + Seq(orderingA, orderingB).foreach { ordering => + assertSortRequirementsAreSatisfied( + childPlan = innerSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = false) + } } } test("EnsureRequirements skips sort when key order of a parent SMJ is propagated from its " + "child SMJ") { - val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB) - val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, Inner, None, childSmj, planC) - // After the second SMJ, exprA, exprB and exprC should all be sorted. - Seq(orderingA, orderingB, orderingC).foreach { ordering => - assertSortRequirementsAreSatisfied( - childPlan = parentSmj, - requiredOrdering = Seq(ordering), - shouldHaveSort = false) + Seq(Inner, Cross).foreach { joinType => + val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, None, planA, planB) + val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, joinType, None, childSmj, planC) + // After the second SMJ, exprA, exprB and exprC should all be sorted. + Seq(orderingA, orderingB, orderingC).foreach { ordering => + assertSortRequirementsAreSatisfied( + childPlan = parentSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = false) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 1c1931b6a6daf..afccbe5cc6d19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -16,37 +16,36 @@ */ package org.apache.spark.sql.execution -import java.util.Locale - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.test.SharedSQLContext class QueryExecutionSuite extends SharedSQLContext { test("toString() exception/error handling") { - val badRule = new SparkStrategy { - var mode: String = "" - override def apply(plan: LogicalPlan): Seq[SparkPlan] = - mode.toLowerCase(Locale.ROOT) match { - case "exception" => throw new AnalysisException(mode) - case "error" => throw new Error(mode) - case _ => Nil - } - } - spark.experimental.extraStrategies = badRule :: Nil + spark.experimental.extraStrategies = Seq( + new SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil + }) def qe: QueryExecution = new QueryExecution(spark, OneRowRelation) // Nothing! - badRule.mode = "" assert(qe.toString.contains("OneRowRelation")) // Throw an AnalysisException - this should be captured. - badRule.mode = "exception" + spark.experimental.extraStrategies = Seq( + new SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = + throw new AnalysisException("exception") + }) assert(qe.toString.contains("org.apache.spark.sql.AnalysisException")) // Throw an Error - this should not be captured. - badRule.mode = "error" + spark.experimental.extraStrategies = Seq( + new SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = + throw new Error("error") + }) val error = intercept[Error](qe.toString) assert(error.getMessage.contains("error")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index d32716c18ddfb..6761f05bb462a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -669,4 +669,14 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { "positive.")) } } + + test("permanent view should be case-preserving") { + withView("v") { + sql("CREATE VIEW v AS SELECT 1 as aBc") + assert(spark.table("v").schema.head.name == "aBc") + + sql("CREATE OR REPLACE VIEW v AS SELECT 2 as cBa") + assert(spark.table("v").schema.head.name == "cBa") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala index 52e4f047225de..a57514c256b90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala @@ -356,6 +356,46 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext { spark.catalog.dropTempView("nums") } + test("window function: mutiple window expressions specified by range in a single expression") { + val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") + nums.createOrReplaceTempView("nums") + withTempView("nums") { + val expected = + Row(1, 1, 1, 4, null, 8, 25) :: + Row(1, 3, 4, 9, 1, 12, 24) :: + Row(1, 5, 9, 15, 4, 16, 21) :: + Row(1, 7, 16, 21, 8, 9, 16) :: + Row(1, 9, 25, 16, 12, null, 9) :: + Row(0, 2, 2, 6, null, 10, 30) :: + Row(0, 4, 6, 12, 2, 14, 28) :: + Row(0, 6, 12, 18, 6, 18, 24) :: + Row(0, 8, 20, 24, 10, 10, 18) :: + Row(0, 10, 30, 18, 14, null, 10) :: + Nil + + val actual = sql( + """ + |SELECT + | y, + | x, + | sum(x) over w1 as history_sum, + | sum(x) over w2 as period_sum1, + | sum(x) over w3 as period_sum2, + | sum(x) over w4 as period_sum3, + | sum(x) over w5 as future_sum + |FROM nums + |WINDOW + | w1 AS (PARTITION BY y ORDER BY x RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), + | w2 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 2 PRECEDING AND 2 FOLLOWING), + | w3 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 4 PRECEDING AND 2 PRECEDING ), + | w4 AS (PARTITION BY y ORDER BY x RANGE BETWEEN 2 FOLLOWING AND 4 FOLLOWING), + | w5 AS (PARTITION BY y ORDER BY x RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) + """.stripMargin + ) + checkAnswer(actual, expected) + } + } + test("SPARK-7595: Window will cause resolve failed with self join") { checkAnswer(sql( """ @@ -437,7 +477,8 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext { |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDiNG AND CURRENT RoW) """.stripMargin) - withSQLConf("spark.sql.windowExec.buffer.spill.threshold" -> "1") { + withSQLConf("spark.sql.windowExec.buffer.in.memory.threshold" -> "1", + "spark.sql.windowExec.buffer.spill.threshold" -> "2") { assertSpilled(sparkContext, "test with low buffer spill threshold") { checkAnswer(actual, expected) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala new file mode 100644 index 0000000000000..aaf51b5b90111 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +/** + * Tests for the sameResult function for [[SparkPlan]]s. + */ +class SameResultSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("FileSourceScanExec: different orders of data filters and partition filters") { + withTempPath { path => + val tmpDir = path.getCanonicalPath + spark.range(10) + .selectExpr("id as a", "id + 1 as b", "id + 2 as c", "id + 3 as d") + .write + .partitionBy("a", "b") + .parquet(tmpDir) + val df = spark.read.parquet(tmpDir) + // partition filters: a > 1 AND b < 9 + // data filters: c > 1 AND d < 9 + val plan1 = getFileSourceScanExec(df.where("a > 1 AND b < 9 AND c > 1 AND d < 9")) + val plan2 = getFileSourceScanExec(df.where("b < 9 AND a > 1 AND d < 9 AND c > 1")) + assert(plan1.sameResult(plan2)) + } + } + + private def getFileSourceScanExec(df: DataFrame): FileSourceScanExec = { + df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get + .asInstanceOf[FileSourceScanExec] + } + + test("SPARK-20725: partial aggregate should behave correctly for sameResult") { + val df1 = spark.range(10).agg(sum($"id")) + val df2 = spark.range(10).agg(sum($"id")) + assert(df1.queryExecution.executedPlan.sameResult(df2.queryExecution.executedPlan)) + + val df3 = spark.range(10).agg(sumDistinct($"id")) + val df4 = spark.range(10).agg(sumDistinct($"id")) + assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index a4b30a2f8cec1..183c68fd3c016 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -22,8 +22,10 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec +import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions.{avg, broadcast, col, max} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -127,4 +129,24 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { "named_struct('a',id+2, 'b',id+2) as col2") .filter("col1 = col2").count() } + + test("SPARK-21441 SortMergeJoin codegen with CodegenFallback expressions should be disabled") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") { + import testImplicits._ + + val df1 = Seq((1, 1), (2, 2), (3, 3)).toDF("key", "int") + val df2 = Seq((1, "1"), (2, "2"), (3, "3")).toDF("key", "str") + + val df = df1.join(df2, df1("key") === df2("key")) + .filter("int = 2 or reflect('java.lang.Integer', 'valueOf', str) = 1") + .select("int") + + val plan = df.queryExecution.executedPlan + assert(!plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child.children(0) + .isInstanceOf[SortMergeJoinExec]).isDefined) + assert(df.collect() === Array(Row(1), Row(2))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index 8a2993bdf4b28..8a798fb444696 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -107,6 +107,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -148,6 +149,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -187,6 +189,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -225,6 +228,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -273,6 +277,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index 239822b72034a..a6249ce021400 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -43,6 +43,7 @@ object TPCDSQueryBenchmark { .set("spark.driver.memory", "3g") .set("spark.executor.memory", "3g") .set("spark.sql.autoBroadcastJoinThreshold", (20 * 1024 * 1024).toString) + .set("spark.sql.crossJoin.enabled", "true") val spark = SparkSession.builder.config(conf).getOrCreate() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 1e6a6a8ba3362..109b1d9db60d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -414,4 +414,19 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { assert(partitionedAttrs.subsetOf(inMemoryScan.outputSet)) } } + + test("SPARK-20356: pruned InMemoryTableScanExec should have correct ordering and partitioning") { + withSQLConf("spark.sql.shuffle.partitions" -> "200") { + val df1 = Seq(("a", 1), ("b", 1), ("c", 2)).toDF("item", "group") + val df2 = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("item", "id") + val df3 = df1.join(df2, Seq("item")).select($"id", $"group".as("item")).distinct() + + df3.unpersist() + val agg_without_cache = df3.groupBy($"item").count() + + df3.cache() + val agg_with_cache = df3.groupBy($"item").count() + checkAnswer(agg_without_cache, agg_with_cache) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 97c61dc8694bc..8a6bc62fec96c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -530,13 +530,13 @@ class DDLCommandSuite extends PlanTest { """.stripMargin val sql4 = """ - |ALTER TABLE table_name PARTITION (test, dt='2008-08-08', + |ALTER TABLE table_name PARTITION (test=1, dt='2008-08-08', |country='us') SET SERDE 'org.apache.class' WITH SERDEPROPERTIES ('columns'='foo,bar', |'field.delim' = ',') """.stripMargin val sql5 = """ - |ALTER TABLE table_name PARTITION (test, dt='2008-08-08', + |ALTER TABLE table_name PARTITION (test=1, dt='2008-08-08', |country='us') SET SERDEPROPERTIES ('columns'='foo,bar', 'field.delim' = ',') """.stripMargin val parsed1 = parser.parsePlan(sql1) @@ -558,12 +558,12 @@ class DDLCommandSuite extends PlanTest { tableIdent, Some("org.apache.class"), Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), - Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) + Some(Map("test" -> "1", "dt" -> "2008-08-08", "country" -> "us"))) val expected5 = AlterTableSerDePropertiesCommand( tableIdent, None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), - Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) + Some(Map("test" -> "1", "dt" -> "2008-08-08", "country" -> "us"))) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) comparePlans(parsed3, expected3) @@ -832,6 +832,14 @@ class DDLCommandSuite extends PlanTest { assert(e.contains("Found duplicate keys 'a'")) } + test("empty values in non-optional partition specs") { + val e = intercept[ParseException] { + parser.parsePlan( + "SHOW PARTITIONS dbx.tab1 PARTITION (a='1', b)") + }.getMessage + assert(e.contains("Found an empty partition key 'b'")) + } + test("drop table") { val tableName1 = "db.tab" val tableName2 = "tab" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index fe74ab49f91bd..5109c649f4318 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -49,7 +49,8 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo protected override def generateTable( catalog: SessionCatalog, - name: TableIdentifier): CatalogTable = { + name: TableIdentifier, + isDataSource: Boolean = true): CatalogTable = { val storage = CatalogStorageFormat.empty.copy(locationUri = Some(catalog.defaultTablePath(name))) val metadata = new MetadataBuilder() @@ -70,46 +71,6 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo tracksPartitionsInCatalog = true) } - test("alter table: set location (datasource table)") { - testSetLocation(isDatasourceTable = true) - } - - test("alter table: set properties (datasource table)") { - testSetProperties(isDatasourceTable = true) - } - - test("alter table: unset properties (datasource table)") { - testUnsetProperties(isDatasourceTable = true) - } - - test("alter table: set serde (datasource table)") { - testSetSerde(isDatasourceTable = true) - } - - test("alter table: set serde partition (datasource table)") { - testSetSerdePartition(isDatasourceTable = true) - } - - test("alter table: change column (datasource table)") { - testChangeColumn(isDatasourceTable = true) - } - - test("alter table: add partition (datasource table)") { - testAddPartitions(isDatasourceTable = true) - } - - test("alter table: drop partition (datasource table)") { - testDropPartitions(isDatasourceTable = true) - } - - test("alter table: rename partition (datasource table)") { - testRenamePartitions(isDatasourceTable = true) - } - - test("drop table - data source table") { - testDropTable(isDatasourceTable = true) - } - test("create a managed Hive source table") { assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") val tabName = "tbl" @@ -163,7 +124,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive" } - protected def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable + protected def generateTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): CatalogTable private val escapedIdentifier = "`(.+)`".r @@ -205,8 +169,11 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { ignoreIfExists = false) } - private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = { - catalog.createTable(generateTable(catalog, name), ignoreIfExists = false) + private def createTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): Unit = { + catalog.createTable(generateTable(catalog, name, isDataSource), ignoreIfExists = false) } private def createTablePartition( @@ -223,6 +190,46 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { new Path(CatalogUtils.URIToString(warehousePath), s"$dbName.db").toUri } + test("alter table: set location (datasource table)") { + testSetLocation(isDatasourceTable = true) + } + + test("alter table: set properties (datasource table)") { + testSetProperties(isDatasourceTable = true) + } + + test("alter table: unset properties (datasource table)") { + testUnsetProperties(isDatasourceTable = true) + } + + test("alter table: set serde (datasource table)") { + testSetSerde(isDatasourceTable = true) + } + + test("alter table: set serde partition (datasource table)") { + testSetSerdePartition(isDatasourceTable = true) + } + + test("alter table: change column (datasource table)") { + testChangeColumn(isDatasourceTable = true) + } + + test("alter table: add partition (datasource table)") { + testAddPartitions(isDatasourceTable = true) + } + + test("alter table: drop partition (datasource table)") { + testDropPartitions(isDatasourceTable = true) + } + + test("alter table: rename partition (datasource table)") { + testRenamePartitions(isDatasourceTable = true) + } + + test("drop table - data source table") { + testDropTable(isDatasourceTable = true) + } + test("the qualified path of a database is stored in the catalog") { val catalog = spark.sessionState.catalog @@ -695,7 +702,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { withView("testview") { sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1 String, c2 String) USING " + "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + - s"OPTIONS (PATH '$tmpFile')") + s"OPTIONS (PATH '${tmpFile.toURI}')") checkAnswer( sql("select c1, c2 from testview order by c1 limit 1"), @@ -707,7 +714,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { s""" |CREATE TEMPORARY VIEW testview |USING org.apache.spark.sql.execution.datasources.csv.CSVFileFormat - |OPTIONS (PATH '$tmpFile') + |OPTIONS (PATH '${tmpFile.toURI}') """.stripMargin) } } @@ -751,7 +758,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val df = (1 to 2).map { i => (i, i.toString) }.toDF("age", "name") df.write.insertInto("students") spark.catalog.cacheTable("students") - assume(spark.table("students").collect().toSeq == df.collect().toSeq, "bad test: wrong data") + checkAnswer(spark.table("students"), df) assume(spark.catalog.isCached("students"), "bad test: table was not cached in the first place") sql("ALTER TABLE students RENAME TO teachers") sql("CREATE TABLE students (age INT, name STRING) USING parquet") @@ -760,7 +767,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(!spark.catalog.isCached("students")) assert(spark.catalog.isCached("teachers")) assert(spark.table("students").collect().isEmpty) - assert(spark.table("teachers").collect().toSeq == df.collect().toSeq) + checkAnswer(spark.table("teachers"), df) } test("rename temporary table - destination table with database name") { @@ -835,32 +842,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("alter table: set location") { - testSetLocation(isDatasourceTable = false) - } - - test("alter table: set properties") { - testSetProperties(isDatasourceTable = false) - } - - test("alter table: unset properties") { - testUnsetProperties(isDatasourceTable = false) - } - - // TODO: move this test to HiveDDLSuite.scala - ignore("alter table: set serde") { - testSetSerde(isDatasourceTable = false) - } - - // TODO: move this test to HiveDDLSuite.scala - ignore("alter table: set serde partition") { - testSetSerdePartition(isDatasourceTable = false) - } - - test("alter table: change column") { - testChangeColumn(isDatasourceTable = false) - } - test("alter table: bucketing is not supported") { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) @@ -885,10 +866,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assertUnsupported("ALTER TABLE dbx.tab1 NOT STORED AS DIRECTORIES") } - test("alter table: add partition") { - testAddPartitions(isDatasourceTable = false) - } - test("alter table: recover partitions (sequential)") { withSQLConf("spark.rdd.parallelListingThreshold" -> "10") { testRecoverPartitions() @@ -957,17 +934,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assertUnsupported("ALTER VIEW dbx.tab1 ADD IF NOT EXISTS PARTITION (b='2')") } - test("alter table: drop partition") { - testDropPartitions(isDatasourceTable = false) - } - test("alter table: drop partition is not supported for views") { assertUnsupported("ALTER VIEW dbx.tab1 DROP IF EXISTS PARTITION (b='2')") } - test("alter table: rename partition") { - testRenamePartitions(isDatasourceTable = false) - } test("show databases") { sql("CREATE DATABASE showdb2B") @@ -1011,18 +981,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(catalog.listTables("default") == Nil) } - test("drop table") { - testDropTable(isDatasourceTable = false) - } - protected def testDropTable(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) assert(catalog.listTables("dbx") == Seq(tableIdent)) sql("DROP TABLE dbx.tab1") assert(catalog.listTables("dbx") == Nil) @@ -1046,22 +1012,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { e.getMessage.contains("Cannot drop a table with DROP VIEW. Please use DROP TABLE instead")) } - private def convertToDatasourceTable( - catalog: SessionCatalog, - tableIdent: TableIdentifier): Unit = { - catalog.alterTable(catalog.getTableMetadata(tableIdent).copy( - provider = Some("csv"))) - assert(catalog.getTableMetadata(tableIdent).provider == Some("csv")) - } - protected def testSetProperties(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getProps: Map[String, String] = { if (isUsingHiveMetastore) { normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties @@ -1084,13 +1042,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testUnsetProperties(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getProps: Map[String, String] = { if (isUsingHiveMetastore) { normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties @@ -1121,15 +1079,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testSetLocation(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val partSpec = Map("a" -> "1", "b" -> "2") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, partSpec, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isDefined) assert(normalizeSerdeProp(catalog.getTableMetadata(tableIdent).storage.properties).isEmpty) assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isDefined) @@ -1171,13 +1129,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testSetSerde(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def checkSerdeProps(expectedSerdeProps: Map[String, String]): Unit = { val serdeProp = catalog.getTableMetadata(tableIdent).storage.properties if (isUsingHiveMetastore) { @@ -1187,8 +1145,12 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } if (isUsingHiveMetastore) { - assert(catalog.getTableMetadata(tableIdent).storage.serde == - Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + val expectedSerde = if (isDatasourceTable) { + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + } else { + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe" + } + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some(expectedSerde)) } else { assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty) } @@ -1229,18 +1191,18 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testSetSerdePartition(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val spec = Map("a" -> "1", "b" -> "2") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, spec, tableIdent) createTablePartition(catalog, Map("a" -> "1", "b" -> "3"), tableIdent) createTablePartition(catalog, Map("a" -> "2", "b" -> "2"), tableIdent) createTablePartition(catalog, Map("a" -> "2", "b" -> "3"), tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } def checkPartitionSerdeProps(expectedSerdeProps: Map[String, String]): Unit = { val serdeProp = catalog.getPartition(tableIdent, spec).storage.properties if (isUsingHiveMetastore) { @@ -1250,8 +1212,12 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } if (isUsingHiveMetastore) { - assert(catalog.getPartition(tableIdent, spec).storage.serde == - Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + val expectedSerde = if (isDatasourceTable) { + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + } else { + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe" + } + assert(catalog.getPartition(tableIdent, spec).storage.serde == Some(expectedSerde)) } else { assert(catalog.getPartition(tableIdent, spec).storage.serde.isEmpty) } @@ -1295,6 +1261,9 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testAddPartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "5") @@ -1303,11 +1272,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val part4 = Map("a" -> "4", "b" -> "8") val part5 = Map("a" -> "9", "b" -> "9") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) // basic add partition @@ -1354,6 +1320,9 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testDropPartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "5") @@ -1362,7 +1331,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val part4 = Map("a" -> "4", "b" -> "8") val part5 = Map("a" -> "9", "b" -> "9") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) createTablePartition(catalog, part2, tableIdent) createTablePartition(catalog, part3, tableIdent) @@ -1370,9 +1339,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { createTablePartition(catalog, part5, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3, part4, part5)) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } // basic drop partition sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (a='4', b='8'), PARTITION (a='3', b='7')") @@ -1407,20 +1373,20 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testRenamePartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "q") val part2 = Map("a" -> "2", "b" -> "c") val part3 = Map("a" -> "3", "b" -> "p") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) createTablePartition(catalog, part2, tableIdent) createTablePartition(catalog, part3, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } // basic rename partition sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") @@ -1451,14 +1417,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testChangeColumn(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val resolver = spark.sessionState.conf.resolver val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getMetadata(colName: String): Metadata = { val column = catalog.getTableMetadata(tableIdent).schema.fields.find { field => resolver(field.name, colName) @@ -1601,13 +1567,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("drop current database") { - sql("CREATE DATABASE temp") - sql("USE temp") - sql("DROP DATABASE temp") - val e = intercept[AnalysisException] { + withDatabase("temp") { + sql("CREATE DATABASE temp") + sql("USE temp") + sql("DROP DATABASE temp") + val e = intercept[AnalysisException] { sql("CREATE TABLE t (a INT, b INT) USING parquet") }.getMessage - assert(e.contains("Database 'temp' not found")) + assert(e.contains("Database 'temp' not found")) + } } test("drop default database") { @@ -1837,22 +1805,25 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { checkAnswer(spark.table("tbl"), Row(1)) val defaultTablePath = spark.sessionState.catalog .getTableMetadata(TableIdentifier("tbl")).storage.locationUri.get - - sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'") - spark.catalog.refreshTable("tbl") - // SET LOCATION won't move data from previous table path to new table path. - assert(spark.table("tbl").count() == 0) - // the previous table path should be still there. - assert(new File(defaultTablePath).exists()) - - sql("INSERT INTO tbl SELECT 2") - checkAnswer(spark.table("tbl"), Row(2)) - // newly inserted data will go to the new table path. - assert(dir.listFiles().nonEmpty) - - sql("DROP TABLE tbl") - // the new table path will be removed after DROP TABLE. - assert(!dir.exists()) + try { + sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'") + spark.catalog.refreshTable("tbl") + // SET LOCATION won't move data from previous table path to new table path. + assert(spark.table("tbl").count() == 0) + // the previous table path should be still there. + assert(new File(defaultTablePath).exists()) + + sql("INSERT INTO tbl SELECT 2") + checkAnswer(spark.table("tbl"), Row(2)) + // newly inserted data will go to the new table path. + assert(dir.listFiles().nonEmpty) + + sql("DROP TABLE tbl") + // the new table path will be removed after DROP TABLE. + assert(!dir.exists()) + } finally { + Utils.deleteRecursively(new File(defaultTablePath)) + } } } } @@ -1864,7 +1835,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { s""" |CREATE TABLE t(a string, b int) |USING parquet - |OPTIONS(path "$dir") + |OPTIONS(path "${dir.toURI}") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) @@ -1882,12 +1853,12 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { checkAnswer(spark.table("t"), Row("c", 1) :: Nil) val newDirFile = new File(dir, "x") - val newDir = newDirFile.getAbsolutePath + val newDir = newDirFile.toURI spark.sql(s"ALTER TABLE t SET LOCATION '$newDir'") spark.sessionState.catalog.refreshTable(TableIdentifier("t")) val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table1.location == new URI(newDir)) + assert(table1.location == newDir) assert(!newDirFile.exists) spark.sql("INSERT INTO TABLE t SELECT 'c', 1") @@ -1905,7 +1876,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { |CREATE TABLE t(a int, b int, c int, d int) |USING parquet |PARTITIONED BY(a, b) - |LOCATION "$dir" + |LOCATION "${dir.toURI}" """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) @@ -1931,7 +1902,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { s""" |CREATE TABLE t(a string, b int) |USING parquet - |OPTIONS(path "$dir") + |OPTIONS(path "${dir.toURI}") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) @@ -1960,7 +1931,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { |CREATE TABLE t(a int, b int, c int, d int) |USING parquet |PARTITIONED BY(a, b) - |LOCATION "$dir" + |LOCATION "${dir.toURI}" """.stripMargin) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) @@ -1977,7 +1948,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { test("create datasource table with a non-existing location") { withTable("t", "t1") { withTempPath { dir => - spark.sql(s"CREATE TABLE t(a int, b int) USING parquet LOCATION '$dir'") + spark.sql(s"CREATE TABLE t(a int, b int) USING parquet LOCATION '${dir.toURI}'") val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) @@ -1989,7 +1960,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } // partition table withTempPath { dir => - spark.sql(s"CREATE TABLE t1(a int, b int) USING parquet PARTITIONED BY(a) LOCATION '$dir'") + spark.sql( + s"CREATE TABLE t1(a int, b int) USING parquet PARTITIONED BY(a) LOCATION '${dir.toURI}'") val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) @@ -2014,7 +1986,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { s""" |CREATE TABLE t |USING parquet - |LOCATION '$dir' + |LOCATION '${dir.toURI}' |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) @@ -2030,7 +2002,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { |CREATE TABLE t1 |USING parquet |PARTITIONED BY(a, b) - |LOCATION '$dir' + |LOCATION '${dir.toURI}' |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) @@ -2047,6 +2019,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { Seq("a b", "a:b", "a%b", "a,b").foreach { specialChars => test(s"data source table:partition column name containing $specialChars") { + // On Windows, it looks colon in the file name is illegal by default. See + // https://support.microsoft.com/en-us/help/289627 + assume(!Utils.isWindows || specialChars != "a:b") + withTable("t") { withTempDir { dir => spark.sql( @@ -2054,14 +2030,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { |CREATE TABLE t(a string, `$specialChars` string) |USING parquet |PARTITIONED BY(`$specialChars`) - |LOCATION '$dir' + |LOCATION '${dir.toURI}' """.stripMargin) assert(dir.listFiles().isEmpty) spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1") val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2" val partFile = new File(dir, partEscaped) - assert(partFile.listFiles().length >= 1) + assert(partFile.listFiles().nonEmpty) checkAnswer(spark.table("t"), Row("1", "2") :: Nil) } } @@ -2070,15 +2046,22 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { Seq("a b", "a:b", "a%b").foreach { specialChars => test(s"location uri contains $specialChars for datasource table") { + // On Windows, it looks colon in the file name is illegal by default. See + // https://support.microsoft.com/en-us/help/289627 + assume(!Utils.isWindows || specialChars != "a:b") + withTable("t", "t1") { withTempDir { dir => val loc = new File(dir, specialChars) loc.mkdir() + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. + val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\") spark.sql( s""" |CREATE TABLE t(a string) |USING parquet - |LOCATION '$loc' + |LOCATION '$escapedLoc' """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) @@ -2087,19 +2070,22 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(loc.listFiles().isEmpty) spark.sql("INSERT INTO TABLE t SELECT 1") - assert(loc.listFiles().length >= 1) + assert(loc.listFiles().nonEmpty) checkAnswer(spark.table("t"), Row("1") :: Nil) } withTempDir { dir => val loc = new File(dir, specialChars) loc.mkdir() + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. + val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\") spark.sql( s""" |CREATE TABLE t1(a string, b string) |USING parquet |PARTITIONED BY(b) - |LOCATION '$loc' + |LOCATION '$escapedLoc' """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) @@ -2109,15 +2095,20 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(loc.listFiles().isEmpty) spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") val partFile = new File(loc, "b=2") - assert(partFile.listFiles().length >= 1) + assert(partFile.listFiles().nonEmpty) checkAnswer(spark.table("t1"), Row("1", "2") :: Nil) spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14") assert(!partFile1.exists()) - val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") - assert(partFile2.listFiles().length >= 1) - checkAnswer(spark.table("t1"), Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + + if (!Utils.isWindows) { + // Actual path becomes "b=2017-03-03%2012%3A13%253A14" on Windows. + val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") + assert(partFile2.listFiles().nonEmpty) + checkAnswer( + spark.table("t1"), Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + } } } } @@ -2125,11 +2116,18 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { Seq("a b", "a:b", "a%b").foreach { specialChars => test(s"location uri contains $specialChars for database") { - try { + // On Windows, it looks colon in the file name is illegal by default. See + // https://support.microsoft.com/en-us/help/289627 + assume(!Utils.isWindows || specialChars != "a:b") + + withDatabase ("tmpdb") { withTable("t") { withTempDir { dir => val loc = new File(dir, specialChars) - spark.sql(s"CREATE DATABASE tmpdb LOCATION '$loc'") + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. + val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\") + spark.sql(s"CREATE DATABASE tmpdb LOCATION '$escapedLoc'") spark.sql("USE tmpdb") import testImplicits._ @@ -2140,8 +2138,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(tblloc.listFiles().nonEmpty) } } - } finally { - spark.sql("DROP DATABASE IF EXISTS tmpdb") } } } @@ -2150,11 +2146,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { withTable("t", "t1") { withTempDir { dir => assert(!dir.getAbsolutePath.startsWith("file:/")) + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. + val escapedDir = dir.getAbsolutePath.replace("\\", "\\\\") spark.sql( s""" |CREATE TABLE t(a string) |USING parquet - |LOCATION '$dir' + |LOCATION '$escapedDir' """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.location.toString.startsWith("file:/")) @@ -2162,12 +2161,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { withTempDir { dir => assert(!dir.getAbsolutePath.startsWith("file:/")) + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. + val escapedDir = dir.getAbsolutePath.replace("\\", "\\\\") spark.sql( s""" |CREATE TABLE t1(a string, b string) |USING parquet |PARTITIONED BY(b) - |LOCATION '$dir' + |LOCATION '$escapedDir' """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) assert(table.location.toString.startsWith("file:/")) @@ -2279,17 +2281,27 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { }.getMessage assert(e.contains("Found duplicate column(s)")) } else { - if (isUsingHiveMetastore) { - // hive catalog will still complains that c1 is duplicate column name because hive - // identifiers are case insensitive. - val e = intercept[AnalysisException] { - sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") - }.getMessage - assert(e.contains("HiveException")) - } else { - sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") - assert(spark.table("t1").schema - .equals(new StructType().add("c1", IntegerType).add("C1", StringType))) + sql("ALTER TABLE t1 ADD COLUMNS (C1 string)") + assert(spark.table("t1").schema == + new StructType().add("c1", IntegerType).add("C1", StringType)) + } + } + } + } + + test(s"basic DDL using locale tr - caseSensitive $caseSensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { + withLocale("tr") { + val dbName = "DaTaBaSe_I" + withDatabase(dbName) { + sql(s"CREATE DATABASE $dbName") + sql(s"USE $dbName") + + val tabName = "tAb_I" + withTable(tabName) { + sql(s"CREATE TABLE $tabName(col_I int) USING PARQUET") + sql(s"INSERT OVERWRITE TABLE $tabName SELECT 1") + checkAnswer(sql(s"SELECT col_I FROM $tabName"), Row(1) :: Nil) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index a9511cbd9e4cf..b4616826e40b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator} @@ -236,6 +237,17 @@ class FileIndexSuite extends SharedSQLContext { val fileStatusCache = FileStatusCache.getOrCreate(spark) fileStatusCache.putLeafFiles(new Path("/tmp", "abc"), files.toArray) } + + test("SPARK-20367 - properly unescape column names in inferPartitioning") { + withTempPath { path => + val colToUnescape = "Column/#%'?" + spark + .range(1) + .select(col("id").as(colToUnescape), col("id")) + .write.partitionBy(colToUnescape).parquet(path.getAbsolutePath) + assert(spark.read.parquet(path.getAbsolutePath).schema.exists(_.name == colToUnescape)) + } + } } class FakeParentPathFileSystem extends RawLocalFileSystem { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index f36162858bf7a..fa3c69612704d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -42,7 +42,7 @@ import org.apache.spark.util.Utils class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper { import testImplicits._ - protected override val sparkConf = new SparkConf().set("spark.default.parallelism", "1") + protected override def sparkConf = super.sparkConf.set("spark.default.parallelism", "1") test("unpartitioned table, single partition") { val table = @@ -395,7 +395,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi val fileCatalog = new InMemoryFileIndex( sparkSession = spark, - rootPaths = Seq(new Path(tempDir)), + rootPathsSpecified = Seq(new Path(tempDir)), parameters = Map.empty[String, String], partitionSchema = None) // This should not fail. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 352dba79a4c08..89d9b69dec7ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -261,10 +261,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for DROPMALFORMED parsing mode") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val cars = spark.read .format("csv") - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .options(Map("header" -> "true", "mode" -> "dropmalformed")) .load(testFile(carsFile)) @@ -284,11 +284,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for FAILFAST parsing mode") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val exception = intercept[SparkException] { spark.read .format("csv") - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .options(Map("header" -> "true", "mode" -> "failfast")) .load(testFile(carsFile)).collect() } @@ -990,13 +990,13 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("SPARK-18699 put malformed records in a `columnNameOfCorruptRecord` field") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val schema = new StructType().add("a", IntegerType).add("b", TimestampType) // We use `PERMISSIVE` mode by default if invalid string is given. val df1 = spark .read .option("mode", "abcd") - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schema) .csv(testFile(valueMalformedFile)) checkAnswer(df1, @@ -1011,7 +1011,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .read .option("mode", "Permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schemaWithCorrField1) .csv(testFile(valueMalformedFile)) checkAnswer(df2, @@ -1028,7 +1028,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .read .option("mode", "permissive") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schemaWithCorrField2) .csv(testFile(valueMalformedFile)) checkAnswer(df3, @@ -1041,7 +1041,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .read .option("mode", "PERMISSIVE") .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .schema(schema.add(columnNameOfCorruptRecord, IntegerType)) .csv(testFile(valueMalformedFile)) .collect @@ -1073,7 +1073,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val df = spark.read .option("header", true) - .option("wholeFile", true) + .option("multiLine", true) .csv(path.getAbsolutePath) // Check if headers have new lines in the names. @@ -1096,10 +1096,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Empty file produces empty dataframe with empty schema") { - Seq(false, true).foreach { wholeFile => + Seq(false, true).foreach { multiLine => val df = spark.read.format("csv") .option("header", true) - .option("wholeFile", wholeFile) + .option("multiLine", multiLine) .load(testFile(emptyFile)) assert(df.schema === spark.emptyDataFrame.schema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 2ab03819964be..f8eb5c569f9ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import java.util.Locale import com.fasterxml.jackson.core.JsonFactory import org.apache.hadoop.fs.{Path, PathFilter} @@ -1803,7 +1804,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(new File(path).listFiles().exists(_.getName.endsWith(".gz"))) - val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDF = spark.read.option("multiLine", true).json(path) val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write .option("compression", "gZiP") @@ -1825,7 +1826,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDF = spark.read.option("multiLine", true).json(path) val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write.json(jsonDir) @@ -1854,7 +1855,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDF = spark.read.option("multiLine", true).json(path) // no corrupt record column should be created assert(jsonDF.schema === StructType(Seq())) // only the first object should be read @@ -1875,7 +1876,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).option("mode", "PERMISSIVE").json(path) + val jsonDF = spark.read.option("multiLine", true).option("mode", "PERMISSIVE").json(path) assert(jsonDF.count() === corruptRecordCount) assert(jsonDF.schema === new StructType() .add("_corrupt_record", StringType) @@ -1906,7 +1907,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .write .text(path) - val jsonDF = spark.read.option("wholeFile", true).option("mode", "DROPMALFORMED").json(path) + val jsonDF = spark.read.option("multiLine", true).option("mode", "DROPMALFORMED").json(path) checkAnswer(jsonDF, Seq(Row("test"))) } } @@ -1929,7 +1930,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // `FAILFAST` mode should throw an exception for corrupt records. val exceptionOne = intercept[SparkException] { spark.read - .option("wholeFile", true) + .option("multiLine", true) .option("mode", "FAILFAST") .json(path) } @@ -1937,7 +1938,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val exceptionTwo = intercept[SparkException] { spark.read - .option("wholeFile", true) + .option("multiLine", true) .option("mode", "FAILFAST") .schema(schema) .json(path) @@ -1978,4 +1979,43 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) } } + + test("SPARK-18772: Parse special floats correctly") { + val jsons = Seq( + """{"a": "NaN"}""", + """{"a": "Infinity"}""", + """{"a": "-Infinity"}""") + + // positive cases + val checks: Seq[Double => Boolean] = Seq( + _.isNaN, + _.isPosInfinity, + _.isNegInfinity) + + Seq(FloatType, DoubleType).foreach { dt => + jsons.zip(checks).foreach { case (json, check) => + val ds = spark.read + .schema(StructType(Seq(StructField("a", dt)))) + .json(Seq(json).toDS()) + .select($"a".cast(DoubleType)).as[Double] + assert(check(ds.first())) + } + } + + // negative cases + Seq(FloatType, DoubleType).foreach { dt => + val lowerCasedJsons = jsons.map(_.toLowerCase(Locale.ROOT)) + // The special floats are case-sensitive so these cases below throw exceptions. + lowerCasedJsons.foreach { lowerCasedJson => + val e = intercept[SparkException] { + spark.read + .option("mode", "FAILFAST") + .schema(StructType(Seq(StructField("a", dt)))) + .json(Seq(lowerCasedJson).toDS()) + .collect() + } + assert(e.getMessage.contains("Cannot parse")) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 9a3328fcecee8..98427cfe3031c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.util.{AccumulatorContext, LongAccumulator} +import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -499,18 +499,20 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val path = s"${dir.getCanonicalPath}/table" (1 to 1024).map(i => (101, i)).toDF("a", "b").write.parquet(path) - Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0)).map { case (push, func) => - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) { - val accu = new LongAccumulator - accu.register(sparkContext, Some("numRowGroups")) + Seq(true, false).foreach { enablePushDown => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> enablePushDown.toString) { + val accu = new NumRowGroupsAcc + sparkContext.register(accu) val df = spark.read.parquet(path).filter("a < 100") df.foreachPartition(_.foreach(v => accu.add(0))) df.collect - val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups") - assert(numRowGroups.isDefined) - assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value)) + if (enablePushDown) { + assert(accu.value == 0) + } else { + assert(accu.value > 0) + } AccumulatorContext.remove(accu.id) } } @@ -536,4 +538,43 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // scalastyle:on nonascii } } + + test("SPARK-20364: Disable Parquet predicate pushdown for fields having dots in the names") { + import testImplicits._ + + Seq(true, false).foreach { vectorized => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString, + SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> true.toString) { + withTempPath { path => + Seq(Some(1), None).toDF("col.dots").write.parquet(path.getAbsolutePath) + val readBack = spark.read.parquet(path.getAbsolutePath).where("`col.dots` IS NOT NULL") + assert(readBack.count() == 1) + } + } + } + } +} + +class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { + private var _sum = 0 + + override def isZero: Boolean = _sum == 0 + + override def copy(): AccumulatorV2[Integer, Integer] = { + val acc = new NumRowGroupsAcc() + acc._sum = _sum + acc + } + + override def reset(): Unit = _sum = 0 + + override def add(v: Integer): Unit = _sum += v + + override def merge(other: AccumulatorV2[Integer, Integer]): Unit = other match { + case a: NumRowGroupsAcc => _sum += a._sum + case _ => throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + } + + override def value: Integer = _sum } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index b4f3de9961209..7225693e50279 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -1022,4 +1022,16 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } } } + + test("SPARK-22109: Resolve type conflicts between strings and timestamps in partition column") { + val df = Seq( + (1, "2015-01-01 00:00:00"), + (2, "2014-01-01 00:00:00"), + (3, "blah")).toDF("i", "str") + + withTempPath { path => + df.write.format("parquet").partitionBy("str").save(path.getAbsolutePath) + checkAnswer(spark.read.load(path.getAbsolutePath), df) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index c36609586c807..2efff3f57d7d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -23,7 +23,7 @@ import java.sql.Timestamp import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.parquet.hadoop.ParquetOutputFormat -import org.apache.spark.SparkException +import org.apache.spark.{DebugFilesystem, SparkException} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow @@ -316,6 +316,39 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + /** + * this is part of test 'Enabling/disabling ignoreCorruptFiles' but run in a loop + * to increase the chance of failure + */ + ignore("SPARK-20407 ParquetQuerySuite 'Enabling/disabling ignoreCorruptFiles' flaky test") { + def testIgnoreCorruptFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.parquet( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer( + df, + Seq(Row(0), Row(1))) + } + } + + for (i <- 1 to 100) { + DebugFilesystem.clearOpenStreams() + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { + val exception = intercept[SparkException] { + testIgnoreCorruptFiles() + } + assert(exception.getMessage().contains("is not a Parquet file")) + } + DebugFilesystem.assertNoOpenStreams() + } + } + test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { withTempPath { dir => val basePath = dir.getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 26c45e092dc65..afb8ced53e25c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -157,7 +157,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } test("broadcast hint in SQL") { - import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, Join} + import org.apache.spark.sql.catalyst.plans.logical.{ResolvedHint, Join} spark.range(10).createOrReplaceTempView("t") spark.range(10).createOrReplaceTempView("u") @@ -170,12 +170,12 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { val plan3 = sql(s"SELECT /*+ $name(v) */ * FROM t JOIN u ON t.id = u.id").queryExecution .optimizedPlan - assert(plan1.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) - assert(!plan1.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) - assert(!plan2.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) - assert(plan2.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) - assert(!plan3.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) - assert(!plan3.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) + assert(plan1.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) + assert(!plan1.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) + assert(!plan2.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) + assert(plan2.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) + assert(!plan3.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) + assert(!plan3.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 2ce7db6a22c01..79d1fbfa3f072 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -143,6 +143,24 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } + test("ObjectHashAggregate metrics") { + // Assume the execution plan is + // ... -> ObjectHashAggregate(nodeId = 2) -> Exchange(nodeId = 1) + // -> ObjectHashAggregate(nodeId = 0) + val df = testData2.groupBy().agg(collect_set('a)) // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("ObjectHashAggregate", Map("number of output rows" -> 2L)), + 0L -> ("ObjectHashAggregate", Map("number of output rows" -> 1L))) + ) + + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).agg(collect_set('a)) + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("ObjectHashAggregate", Map("number of output rows" -> 4L)), + 0L -> ("ObjectHashAggregate", Map("number of output rows" -> 3L))) + ) + } + test("Sort metrics") { // Assume the execution plan is // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1)) @@ -270,6 +288,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } } + test("SortMergeJoin(left-anti) metrics") { + val anti = testData2.filter("a > 2") + withTempView("antiData") { + anti.createOrReplaceTempView("antiData") + val df = spark.sql( + "SELECT * FROM testData2 ANTI JOIN antiData ON testData2.a = antiData.a") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("SortMergeJoin", Map("number of output rows" -> 4L))) + ) + } + } + test("save metrics") { withTempPath { file => val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 20ac06f048c6f..3d480b148db55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.test.SharedSQLContext class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext { /** To avoid caching of FS objects */ - override protected val sparkConf = - new SparkConf().set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") + override protected def sparkConf = + super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") import CompactibleFileStreamLog._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 662c4466b21b2..48e70e48b1799 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -38,8 +38,8 @@ import org.apache.spark.util.UninterruptibleThread class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { /** To avoid caching of FS objects */ - override protected val sparkConf = - new SparkConf().set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") + override protected def sparkConf = + super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") private implicit def toOption[A](a: A): Option[A] = Option(a) @@ -259,6 +259,23 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { fm.rename(path2, path3) } } + + test("verifyBatchIds") { + import HDFSMetadataLog.verifyBatchIds + verifyBatchIds(Seq(1L, 2L, 3L), Some(1L), Some(3L)) + verifyBatchIds(Seq(1L), Some(1L), Some(1L)) + verifyBatchIds(Seq(1L, 2L, 3L), None, Some(3L)) + verifyBatchIds(Seq(1L, 2L, 3L), Some(1L), None) + verifyBatchIds(Seq(1L, 2L, 3L), None, None) + + intercept[IllegalStateException](verifyBatchIds(Seq(), Some(1L), None)) + intercept[IllegalStateException](verifyBatchIds(Seq(), None, Some(1L))) + intercept[IllegalStateException](verifyBatchIds(Seq(), Some(1L), Some(1L))) + intercept[IllegalStateException](verifyBatchIds(Seq(2, 3, 4), Some(1L), None)) + intercept[IllegalStateException](verifyBatchIds(Seq(2, 3, 4), None, Some(5L))) + intercept[IllegalStateException](verifyBatchIds(Seq(2, 3, 4), Some(1L), Some(5L))) + intercept[IllegalStateException](verifyBatchIds(Seq(1, 2, 4, 5), Some(1L), Some(5L))) + } } /** FakeFileSystem to test fallback of the HDFSMetadataLog from FileContext to FileSystem API */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala new file mode 100644 index 0000000000000..bdba536425a43 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} +import org.apache.spark.util.ManualClock + +class RateSourceSuite extends StreamTest { + + import testImplicits._ + + case class AdvanceRateManualClock(seconds: Long) extends AddData { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + assert(query.nonEmpty) + val rateSource = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] => + source.asInstanceOf[RateStreamSource] + }.head + rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) + (rateSource, rateSource.getOffset.get) + } + } + + test("basic") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("useManualClock", "true") + .load() + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*), + StopStream, + StartStream(), + // Advance 2 seconds because creating a new RateSource will also create a new ManualClock + AdvanceRateManualClock(seconds = 2), + CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*) + ) + } + + test("uniform distribution of event timestamps") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "1500") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + val expectedAnswer = (0 until 1500).map { v => + (math.round(v * (1000.0 / 1500)), v) + } + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch(expectedAnswer: _*) + ) + } + + test("valueAtSecond") { + import RateStreamSource._ + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8) + + assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0) + assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2) + assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6) + assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12) + assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20) + assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30) + } + + test("rampUpTime") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("rampUpTime", "4s") + .option("useManualClock", "true") + .load() + .as[(java.sql.Timestamp, Long)] + .map(v => (v._1.getTime, v._2)) + testStream(input)( + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch({ + Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11) + }: _*), // speed = 6 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8 + AdvanceRateManualClock(seconds = 1), + // Now we should reach full speed + CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10 + AdvanceRateManualClock(seconds = 1), + CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10 + ) + } + + test("numPartitions") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", "10") + .option("numPartitions", "6") + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(1), + CheckLastBatch((0 until 6): _*) + ) + } + + testQuietly("overflow") { + val input = spark.readStream + .format("rate") + .option("rowsPerSecond", Long.MaxValue.toString) + .option("useManualClock", "true") + .load() + .select(spark_partition_id()) + .distinct() + testStream(input)( + AdvanceRateManualClock(2), + ExpectFailure[ArithmeticException](t => { + Seq("overflow", "rowsPerSecond").foreach { msg => + assert(t.getMessage.contains(msg)) + } + }) + ) + } + + testQuietly("illegal option values") { + def testIllegalOptionValue( + option: String, + value: String, + expectedMessages: Seq[String]): Unit = { + val e = intercept[StreamingQueryException] { + spark.readStream + .format("rate") + .option(option, value) + .load() + .writeStream + .format("console") + .start() + .awaitTermination() + } + assert(e.getCause.isInstanceOf[IllegalArgumentException]) + for (msg <- expectedMessages) { + assert(e.getCause.getMessage.contains(msg)) + } + } + + testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive")) + testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index ebb7422765ebb..cc09b2d5b7763 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -314,7 +314,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { val conf = new Configuration() conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName) - conf.set("fs.default.name", "fake:///") + conf.set("fs.defaultFS", "fake:///") val provider = newStoreProvider(hadoopConf = conf) provider.getStore(0).commit() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala new file mode 100644 index 0000000000000..19b93c9257212 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.vectorized + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { + + var testVector: ColumnVector = _ + + private def allocate(capacity: Int, dt: DataType): ColumnVector = { + new OnHeapColumnVector(capacity, dt) + } + + override def afterEach(): Unit = { + testVector.close() + } + + test("boolean") { + testVector = allocate(10, BooleanType) + (0 until 10).foreach { i => + testVector.appendBoolean(i % 2 == 0) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.getBoolean(i) === (i % 2 == 0)) + } + } + + test("byte") { + testVector = allocate(10, ByteType) + (0 until 10).foreach { i => + testVector.appendByte(i.toByte) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.getByte(i) === (i.toByte)) + } + } + + test("short") { + testVector = allocate(10, ShortType) + (0 until 10).foreach { i => + testVector.appendShort(i.toShort) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.getShort(i) === (i.toShort)) + } + } + + test("int") { + testVector = allocate(10, IntegerType) + (0 until 10).foreach { i => + testVector.appendInt(i) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.getInt(i) === i) + } + } + + test("long") { + testVector = allocate(10, LongType) + (0 until 10).foreach { i => + testVector.appendLong(i) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.getLong(i) === i) + } + } + + test("float") { + testVector = allocate(10, FloatType) + (0 until 10).foreach { i => + testVector.appendFloat(i.toFloat) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.getFloat(i) === i.toFloat) + } + } + + test("double") { + testVector = allocate(10, DoubleType) + (0 until 10).foreach { i => + testVector.appendDouble(i.toDouble) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.getDouble(i) === i.toDouble) + } + } + + test("string") { + testVector = allocate(10, StringType) + (0 until 10).map { i => + val utf8 = s"str$i".getBytes("utf8") + testVector.appendByteArray(utf8, 0, utf8.length) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.getUTF8String(i) === UTF8String.fromString(s"str$i")) + } + } + + test("binary") { + testVector = allocate(10, BinaryType) + (0 until 10).map { i => + val utf8 = s"str$i".getBytes("utf8") + testVector.appendByteArray(utf8, 0, utf8.length) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + val utf8 = s"str$i".getBytes("utf8") + assert(array.getBinary(i) === utf8) + } + } + + test("array") { + val arrayType = ArrayType(IntegerType, true) + testVector = allocate(10, arrayType) + + val data = testVector.arrayData() + var i = 0 + while (i < 6) { + data.putInt(i, i) + i += 1 + } + + // Populate it with arrays [0], [1, 2], [], [3, 4, 5] + testVector.putArray(0, 0, 1) + testVector.putArray(1, 1, 2) + testVector.putArray(2, 3, 0) + testVector.putArray(3, 3, 3) + + val array = new ColumnVector.Array(testVector) + + assert(array.getArray(0).toIntArray() === Array(0)) + assert(array.getArray(1).asInstanceOf[ArrayData].toIntArray() === Array(1, 2)) + assert(array.getArray(2).asInstanceOf[ArrayData].toIntArray() === Array.empty[Int]) + assert(array.getArray(3).asInstanceOf[ArrayData].toIntArray() === Array(3, 4, 5)) + } + + test("struct") { + val schema = new StructType().add("int", IntegerType).add("double", DoubleType) + testVector = allocate(10, schema) + val c1 = testVector.getChildColumn(0) + val c2 = testVector.getChildColumn(1) + c1.putInt(0, 123) + c2.putDouble(0, 3.45) + c1.putInt(1, 456) + c2.putDouble(1, 5.67) + + val array = new ColumnVector.Array(testVector) + + assert(array.getStruct(0, 2).asInstanceOf[ColumnarBatch.Row].getInt(0) === 123) + assert(array.getStruct(0, 2).asInstanceOf[ColumnarBatch.Row].getDouble(1) === 3.45) + assert(array.getStruct(1, 2).asInstanceOf[ColumnarBatch.Row].getInt(0) === 456) + assert(array.getStruct(1, 2).asInstanceOf[ColumnarBatch.Row].getDouble(1) === 5.67) + } + + test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { + val arrayType = ArrayType(IntegerType, true) + testVector = new OffHeapColumnVector(8, arrayType) + + val data = testVector.arrayData() + (0 until 8).foreach(i => data.putInt(i, i)) + (0 until 8).foreach(i => testVector.putArray(i, i, 1)) + + // Increase vector's capacity and reallocate the data to new bigger buffers. + testVector.reserve(16) + + // Check that none of the values got lost/overwritten. + val array = new ColumnVector.Array(testVector) + (0 until 8).foreach { i => + assert(array.getArray(i).toIntArray() === Array(i)) + } + } + + test("[SPARK-22092] off-heap column vector reallocation corrupts struct nullability") { + val structType = new StructType().add("int", IntegerType).add("double", DoubleType) + testVector = new OffHeapColumnVector(8, structType) + (0 until 8).foreach(i => if (i % 2 == 0) testVector.putNull(i) else testVector.putNotNull(i)) + testVector.reserve(16) + (0 until 8).foreach(i => assert(testVector.isNullAt(i) == (i % 2 == 0))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index a283ff971adcd..948f179f5e8f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -270,4 +270,15 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { val e2 = intercept[AnalysisException](spark.conf.unset(SCHEMA_STRING_LENGTH_THRESHOLD.key)) assert(e2.message.contains("Cannot modify the value of a static config")) } + + test("SPARK-21588 SQLContext.getConf(key, null) should return null") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + assert("1" == spark.conf.get(SQLConf.SHUFFLE_PARTITIONS.key, null)) + assert("1" == spark.conf.get(SQLConf.SHUFFLE_PARTITIONS.key, "")) + } + + assert(spark.conf.getOption("spark.sql.nonexistent").isEmpty) + assert(null == spark.conf.get("spark.sql.nonexistent", null)) + assert("" == spark.conf.get("spark.sql.nonexistent", "")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5bd36ec25ccb0..ae4b1bf0f47cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -96,6 +96,15 @@ class JDBCSuite extends SparkFunSuite | partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3') """.stripMargin.replaceAll("\n", " ")) + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW partsoverflow + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass', + | partitionColumn 'THEID', lowerBound '-9223372036854775808', + | upperBound '9223372036854775807', numPartitions '3') + """.stripMargin.replaceAll("\n", " ")) + conn.prepareStatement("create table test.inttypes (a INT, b BOOLEAN, c TINYINT, " + "d SMALLINT, e BIGINT)").executeUpdate() conn.prepareStatement("insert into test.inttypes values (1, false, 3, 4, 1234567890123)" @@ -367,6 +376,12 @@ class JDBCSuite extends SparkFunSuite assert(ids(2) === 3) } + test("overflow of partition bound difference does not give negative stride") { + val df = sql("SELECT * FROM partsoverflow") + checkNumPartitions(df, expectedNumPartitions = 3) + assert(df.collect().length == 3) + } + test("Register JDBC query with renamed fields") { // Regression test for bug SPARK-7345 sql( @@ -397,6 +412,28 @@ class JDBCSuite extends SparkFunSuite assert(e.contains("Invalid value `-1` for parameter `fetchsize`")) } + test("Missing partition columns") { + withView("tempPeople") { + val e = intercept[IllegalArgumentException] { + sql( + s""" + |CREATE OR REPLACE TEMPORARY VIEW tempPeople + |USING org.apache.spark.sql.jdbc + |OPTIONS ( + | url 'jdbc:h2:mem:testdb0;user=testUser;password=testPass', + | dbtable 'TEST.PEOPLE', + | lowerBound '0', + | upperBound '52', + | numPartitions '53', + | fetchSize '10000' ) + """.stripMargin.replaceAll("\n", " ")) + }.getMessage + assert(e.contains("When reading JDBC data sources, users need to specify all or none " + + "for the following options: 'partitionColumn', 'lowerBound', 'upperBound', and " + + "'numPartitions'")) + } + } + test("Basic API with FetchSize") { (0 to 4).foreach { size => val properties = new Properties() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index bf1fd160704fa..d7ae45680fe56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -323,8 +323,9 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { .option("partitionColumn", "foo") .save() }.getMessage - assert(e.contains("If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," + - " and 'numPartitions' are required.")) + assert(e.contains("When reading JDBC data sources, users need to specify all or none " + + "for the following options: 'partitionColumn', 'lowerBound', 'upperBound', and " + + "'numPartitions'")) } test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 9b65419dba234..ba0ca666b5c14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -90,6 +90,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { originalDataFrame: DataFrame): Unit = { // This test verifies parts of the plan. Disable whole stage codegen. withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + val strategy = DataSourceStrategy(spark.sessionState.conf) val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec // Limit: bucket pruning only works when the bucket column has one and only one column @@ -98,7 +99,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex) val matchedBuckets = new BitSet(numBuckets) bucketValues.foreach { value => - matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value)) + matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value)) } // Filter could hide the bug in bucket pruning. Thus, skipping all the filters diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala index 85ba33e58a787..b5fb740b6eb77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -19,26 +19,39 @@ package org.apache.spark.sql.sources import org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{StringType, StructField, StructType} +import org.apache.spark.sql.types._ // please note that the META-INF/services had to be modified for the test directory for this to work class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext { - test("data sources with the same name") { - intercept[RuntimeException] { + test("data sources with the same name - internal data sources") { + val e = intercept[AnalysisException] { spark.read.format("Fluet da Bomb").load() } + assert(e.getMessage.contains("Multiple sources found for Fluet da Bomb")) + } + + test("data sources with the same name - internal data source/external data source") { + assert(spark.read.format("datasource").load().schema == + StructType(Seq(StructField("longType", LongType, nullable = false)))) + } + + test("data sources with the same name - external data sources") { + val e = intercept[AnalysisException] { + spark.read.format("Fake external source").load() + } + assert(e.getMessage.contains("Multiple sources found for Fake external source")) } test("load data source from format alias") { - spark.read.format("gathering quorum").load().schema == - StructType(Seq(StructField("stringType", StringType, nullable = false))) + assert(spark.read.format("gathering quorum").load().schema == + StructType(Seq(StructField("stringType", StringType, nullable = false)))) } test("specify full classname with duplicate formats") { - spark.read.format("org.apache.spark.sql.sources.FakeSourceOne") - .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) + assert(spark.read.format("org.apache.spark.sql.sources.FakeSourceOne") + .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))) } test("should fail to load ORC without Hive Support") { @@ -63,7 +76,7 @@ class FakeSourceOne extends RelationProvider with DataSourceRegister { } } -class FakeSourceTwo extends RelationProvider with DataSourceRegister { +class FakeSourceTwo extends RelationProvider with DataSourceRegister { def shortName(): String = "Fluet da Bomb" @@ -72,7 +85,7 @@ class FakeSourceTwo extends RelationProvider with DataSourceRegister { override def sqlContext: SQLContext = cont override def schema: StructType = - StructType(Seq(StructField("stringType", StringType, nullable = false))) + StructType(Seq(StructField("integerType", IntegerType, nullable = false))) } } @@ -88,3 +101,16 @@ class FakeSourceThree extends RelationProvider with DataSourceRegister { StructType(Seq(StructField("stringType", StringType, nullable = false))) } } + +class FakeSourceFour extends RelationProvider with DataSourceRegister { + + def shortName(): String = "datasource" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("longType", LongType, nullable = false))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala index b16c9f8fc96b2..735e07c21373a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, Literal} import org.apache.spark.sql.execution.datasources.DataSourceAnalysis import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.{DataType, IntegerType, StructType} class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -49,7 +49,11 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { } Seq(true, false).foreach { caseSensitive => - val rule = DataSourceAnalysis(new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)) + val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) + def cast(e: Expression, dt: DataType): Expression = { + Cast(e, dt, Option(conf.sessionLocalTimeZone)) + } + val rule = DataSourceAnalysis(conf) test( s"convertStaticPartitions only handle INSERT having at least static partitions " + s"(caseSensitive: $caseSensitive)") { @@ -150,7 +154,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { if (!caseSensitive) { val nonPartitionedAttributes = Seq('e.int, 'f.int) val expected = nonPartitionedAttributes ++ - Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType)) + Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType)) val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes, providedPartitions = Map("b" -> Some("1"), "C" -> Some("3")), @@ -162,7 +166,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { { val nonPartitionedAttributes = Seq('e.int, 'f.int) val expected = nonPartitionedAttributes ++ - Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType)) + Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType)) val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes, providedPartitions = Map("b" -> Some("1"), "c" -> Some("3")), @@ -174,7 +178,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { // Test the case having a single static partition column. { val nonPartitionedAttributes = Seq('e.int, 'f.int) - val expected = nonPartitionedAttributes ++ Seq(Cast(Literal("1"), IntegerType)) + val expected = nonPartitionedAttributes ++ Seq(cast(Literal("1"), IntegerType)) val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes, providedPartitions = Map("b" -> Some("1")), @@ -189,7 +193,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { val dynamicPartitionAttributes = Seq('g.int) val expected = nonPartitionedAttributes ++ - Seq(Cast(Literal("1"), IntegerType)) ++ + Seq(cast(Literal("1"), IntegerType)) ++ dynamicPartitionAttributes val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes ++ dynamicPartitionAttributes, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 2eae66dda88de..41abff2a5da25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -345,4 +345,25 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { ) } } + + test("SPARK-21203 wrong results of insertion of Array of Struct") { + val tabName = "tab1" + withTable(tabName) { + spark.sql( + """ + |CREATE TABLE `tab1` + |(`custom_fields` ARRAY>) + |USING parquet + """.stripMargin) + spark.sql( + """ + |INSERT INTO `tab1` + |SELECT ARRAY(named_struct('id', 1, 'value', 'a'), named_struct('id', 2, 'value', 'b')) + """.stripMargin) + + checkAnswer( + spark.sql("SELECT custom_fields.id, custom_fields.value FROM tab1"), + Row(Array(1, 2), Array("a", "b"))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala index 6dd4847ead738..c25c3f62158cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala @@ -92,12 +92,12 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { s""" |CREATE TABLE src |USING ${classOf[TestOptionsSource].getCanonicalName} - |OPTIONS (PATH '$p') + |OPTIONS (PATH '${p.toURI}') |AS SELECT 1 """.stripMargin) assert( spark.table("src").schema.head.metadata.getString("path") == - p.getAbsolutePath) + p.toURI.toString) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala new file mode 100644 index 0000000000000..0dfd75e709123 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/fakeExternalSources.scala @@ -0,0 +1,64 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.fakesource + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider} +import org.apache.spark.sql.types._ + + +// Note that the package name is intendedly mismatched in order to resemble external data sources +// and test the detection for them. +class FakeExternalSourceOne extends RelationProvider with DataSourceRegister { + + def shortName(): String = "Fake external source" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("stringType", StringType, nullable = false))) + } +} + +class FakeExternalSourceTwo extends RelationProvider with DataSourceRegister { + + def shortName(): String = "Fake external source" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("integerType", IntegerType, nullable = false))) + } +} + +class FakeExternalSourceThree extends RelationProvider with DataSourceRegister { + + def shortName(): String = "datasource" + + override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation = + new BaseRelation { + override def sqlContext: SQLContext = cont + + override def schema: StructType = + StructType(Seq(StructField("byteType", ByteType, nullable = false))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala index a15c2cff930fc..e858b7d9998a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala @@ -268,4 +268,17 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { CheckLastBatch(7) ) } + + test("SPARK-21546: dropDuplicates should ignore watermark when it's not a key") { + val input = MemoryStream[(Int, Int)] + val df = input.toDS.toDF("id", "time") + .withColumn("time", $"time".cast("timestamp")) + .withWatermark("time", "1 second") + .dropDuplicates("id") + .select($"id", $"time".cast("long")) + testStream(df)( + AddData(input, 1 -> 1, 1 -> 2, 2 -> 2), + CheckLastBatch(1 -> 1, 2 -> 2) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index fd850a7365e20..4f19fa0bb4a97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -21,7 +21,7 @@ import java.{util => ju} import java.text.SimpleDateFormat import java.util.Date -import org.scalatest.BeforeAndAfter +import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.OutputMode._ -class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Logging { +class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matchers with Logging { import testImplicits._ @@ -38,6 +38,43 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin sqlContext.streams.active.foreach(_.stop()) } + test("EventTimeStats") { + val epsilon = 10E-6 + + val stats = EventTimeStats(max = 100, min = 10, avg = 20.0, count = 5) + stats.add(80L) + stats.max should be (100) + stats.min should be (10) + stats.avg should be (30.0 +- epsilon) + stats.count should be (6) + + val stats2 = EventTimeStats(80L, 5L, 15.0, 4) + stats.merge(stats2) + stats.max should be (100) + stats.min should be (5) + stats.avg should be (24.0 +- epsilon) + stats.count should be (10) + } + + test("EventTimeStats: avg on large values") { + val epsilon = 10E-6 + val largeValue = 10000000000L // 10B + // Make sure `largeValue` will cause overflow if we use a Long sum to calc avg. + assert(largeValue * largeValue != BigInt(largeValue) * BigInt(largeValue)) + val stats = + EventTimeStats(max = largeValue, min = largeValue, avg = largeValue, count = largeValue - 1) + stats.add(largeValue) + stats.avg should be (largeValue.toDouble +- epsilon) + + val stats2 = EventTimeStats( + max = largeValue + 1, + min = largeValue, + avg = largeValue + 1, + count = largeValue) + stats.merge(stats2) + stats.avg should be ((largeValue + 0.5) +- epsilon) + } + test("error on bad column") { val inputData = MemoryStream[Int].toDF() val e = intercept[AnalysisException] { @@ -344,6 +381,44 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin assert(eventTimeColumns(0).name === "second") } + test("EventTime watermark should be ignored in batch query.") { + val df = testData + .withColumn("eventTime", $"key".cast("timestamp")) + .withWatermark("eventTime", "1 minute") + .select("eventTime") + .as[Long] + + checkDataset[Long](df, 1L to 100L: _*) + } + + test("SPARK-21565: watermark operator accepts attributes from replacement") { + withTempDir { dir => + dir.delete() + + val df = Seq(("a", 100.0, new java.sql.Timestamp(100L))) + .toDF("symbol", "price", "eventTime") + df.write.json(dir.getCanonicalPath) + + val input = spark.readStream.schema(df.schema) + .json(dir.getCanonicalPath) + + val groupEvents = input + .withWatermark("eventTime", "2 seconds") + .groupBy("symbol", "eventTime") + .agg(count("price") as 'count) + .select("symbol", "eventTime", "count") + val q = groupEvents.writeStream + .outputMode("append") + .format("console") + .start() + try { + q.processAllAvailable() + } finally { + q.stop() + } + } + } + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 1211242b9fbb4..bb6a27803bb20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.streaming import java.util.Locale +import org.apache.hadoop.fs.Path + import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.streaming.{MemoryStream, MetadataLogFileIndex} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -62,6 +64,35 @@ class FileStreamSinkSuite extends StreamTest { } } + test("SPARK-21167: encode and decode path correctly") { + val inputData = MemoryStream[String] + val ds = inputData.toDS() + + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + + val query = ds.map(s => (s, s.length)) + .toDF("value", "len") + .writeStream + .partitionBy("value") + .option("checkpointLocation", checkpointDir) + .format("parquet") + .start(outputDir) + + try { + // The output is partitoned by "value", so the value will appear in the file path. + // This is to test if we handle spaces in the path correctly. + inputData.addData("hello world") + failAfter(streamingTimeout) { + query.processAllAvailable() + } + val outputDf = spark.read.parquet(outputDir) + checkDatasetUnorderly(outputDf.as[(Int, String)], ("hello world".length, "hello world")) + } finally { + query.stop() + } + } + test("partitioned writing and batch reading") { val inputData = MemoryStream[Int] val ds = inputData.toDS() @@ -145,6 +176,43 @@ class FileStreamSinkSuite extends StreamTest { } } + test("partitioned writing and batch reading with 'basePath'") { + withTempDir { outputDir => + withTempDir { checkpointDir => + val outputPath = outputDir.getAbsolutePath + val inputData = MemoryStream[Int] + val ds = inputData.toDS() + + var query: StreamingQuery = null + + try { + query = + ds.map(i => (i, -i, i * 1000)) + .toDF("id1", "id2", "value") + .writeStream + .partitionBy("id1", "id2") + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .format("parquet") + .start(outputPath) + + inputData.addData(1, 2, 3) + failAfter(streamingTimeout) { + query.processAllAvailable() + } + + val readIn = spark.read.option("basePath", outputPath).parquet(s"$outputDir/*/*") + checkDatasetUnorderly( + readIn.as[(Int, Int, Int)], + (1000, 1, -1), (2000, 2, -2), (3000, 3, -3)) + } finally { + if (query != null) { + query.stop() + } + } + } + } + } + // This tests whether FileStreamSink works with aggregations. Specifically, it tests // whether the correct streaming QueryExecution (i.e. IncrementalExecution) is used to // to execute the trigger for writing data to file sink. See SPARK-18440 for more details. @@ -266,4 +334,22 @@ class FileStreamSinkSuite extends StreamTest { } } } + + test("FileStreamSink.ancestorIsMetadataDirectory()") { + val hadoopConf = spark.sparkContext.hadoopConfiguration + def assertAncestorIsMetadataDirectory(path: String): Unit = + assert(FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) + def assertAncestorIsNotMetadataDirectory(path: String): Unit = + assert(!FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) + + assertAncestorIsMetadataDirectory(s"/${FileStreamSink.metadataDir}") + assertAncestorIsMetadataDirectory(s"/${FileStreamSink.metadataDir}/") + assertAncestorIsMetadataDirectory(s"/a/${FileStreamSink.metadataDir}") + assertAncestorIsMetadataDirectory(s"/a/${FileStreamSink.metadataDir}/") + assertAncestorIsMetadataDirectory(s"/a/b/${FileStreamSink.metadataDir}/c") + assertAncestorIsMetadataDirectory(s"/a/b/${FileStreamSink.metadataDir}/c/") + + assertAncestorIsNotMetadataDirectory(s"/a/b/c") + assertAncestorIsNotMetadataDirectory(s"/a/b/c/${FileStreamSink.metadataDir}extra") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 2108b118bf059..e2ec690d90e52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -1314,6 +1314,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { val metadataLog = new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, dir.getAbsolutePath) assert(metadataLog.add(0, Array(FileEntry(s"$scheme:///file1", 100L, 0)))) + assert(metadataLog.add(1, Array(FileEntry(s"$scheme:///file2", 200L, 0)))) val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil), Nil, dir.getAbsolutePath, Map.empty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 85aa7dbe9ed86..d7370642d08b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -73,14 +73,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf assert(state.hasRemoved === shouldBeRemoved) } + // === Tests for state in streaming queries === // Updating empty state - state = new GroupStateImpl[String](None) + state = GroupStateImpl.createForStreaming(None, 1, 1, NoTimeout, hasTimedOut = false) testState(None) state.update("") testState(Some(""), shouldBeUpdated = true) // Updating exiting state - state = new GroupStateImpl[String](Some("2")) + state = GroupStateImpl.createForStreaming(Some("2"), 1, 1, NoTimeout, hasTimedOut = false) testState(Some("2")) state.update("3") testState(Some("3"), shouldBeUpdated = true) @@ -99,24 +100,34 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } test("GroupState - setTimeout**** with NoTimeout") { - for (initState <- Seq(None, Some(5))) { - // for different initial state - implicit val state = new GroupStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false) - testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + for (initValue <- Seq(None, Some(5))) { + val states = Seq( + GroupStateImpl.createForStreaming(initValue, 1000, 1000, NoTimeout, hasTimedOut = false), + GroupStateImpl.createForBatch(NoTimeout) + ) + for (state <- states) { + // for streaming queries + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + // for batch queries + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + } } } test("GroupState - setTimeout**** with ProcessingTimeTimeout") { - implicit var state: GroupStateImpl[Int] = null - - state = new GroupStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + // for streaming queries + var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( + None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) - testTimeoutDurationNotAllowed[IllegalStateException](state) + state.setTimeoutDuration(500) + assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) state.update(5) - assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + assert(state.getTimeoutTimestamp === 1500) // does not change state.setTimeoutDuration(1000) assert(state.getTimeoutTimestamp === 2000) state.setTimeoutDuration("2 second") @@ -124,19 +135,38 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) state.remove() + assert(state.getTimeoutTimestamp === 3000) // does not change + state.setTimeoutDuration(500) // can still be set + assert(state.getTimeoutTimestamp === 1500) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + // for batch queries + state = GroupStateImpl.createForBatch(ProcessingTimeTimeout).asInstanceOf[GroupStateImpl[Int]] assert(state.getTimeoutTimestamp === NO_TIMESTAMP) - testTimeoutDurationNotAllowed[IllegalStateException](state) + state.setTimeoutDuration(500) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.update(5) + state.setTimeoutDuration(1000) + state.setTimeoutDuration("2 second") + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.remove() + state.setTimeoutDuration(500) testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) } test("GroupState - setTimeout**** with EventTimeTimeout") { - implicit val state = new GroupStateImpl[Int]( - None, 1000, 1000, EventTimeTimeout, hasTimedOut = false) + var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( + None, 1000, 1000, EventTimeTimeout, false) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - testTimeoutTimestampNotAllowed[IllegalStateException](state) + state.setTimeoutTimestamp(5000) + assert(state.getTimeoutTimestamp === 5000) // can be set without initializing state state.update(5) + assert(state.getTimeoutTimestamp === 5000) // does not change state.setTimeoutTimestamp(10000) assert(state.getTimeoutTimestamp === 10000) state.setTimeoutTimestamp(new Date(20000)) @@ -144,9 +174,25 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testTimeoutDurationNotAllowed[UnsupportedOperationException](state) state.remove() + assert(state.getTimeoutTimestamp === 20000) + state.setTimeoutTimestamp(5000) + assert(state.getTimeoutTimestamp === 5000) // can be set after removing state + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + + // for batch queries + state = GroupStateImpl.createForBatch(EventTimeTimeout).asInstanceOf[GroupStateImpl[Int]] assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) - testTimeoutTimestampNotAllowed[IllegalStateException](state) + state.setTimeoutTimestamp(5000) + + state.update(5) + state.setTimeoutTimestamp(10000) + state.setTimeoutTimestamp(new Date(20000)) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + + state.remove() + state.setTimeoutTimestamp(5000) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) } test("GroupState - illegal params to setTimeout****") { @@ -154,47 +200,86 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf // Test setTimeout****() with illegal values def testIllegalTimeout(body: => Unit): Unit = { - intercept[IllegalArgumentException] { body } + intercept[IllegalArgumentException] { + body + } assert(state.getTimeoutTimestamp === NO_TIMESTAMP) } - state = new GroupStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) - testIllegalTimeout { state.setTimeoutDuration(-1000) } - testIllegalTimeout { state.setTimeoutDuration(0) } - testIllegalTimeout { state.setTimeoutDuration("-2 second") } - testIllegalTimeout { state.setTimeoutDuration("-1 month") } - testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") } + state = GroupStateImpl.createForStreaming( + Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + testIllegalTimeout { + state.setTimeoutDuration(-1000) + } + testIllegalTimeout { + state.setTimeoutDuration(0) + } + testIllegalTimeout { + state.setTimeoutDuration("-2 second") + } + testIllegalTimeout { + state.setTimeoutDuration("-1 month") + } + testIllegalTimeout { + state.setTimeoutDuration("1 month -1 day") + } - state = new GroupStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) - testIllegalTimeout { state.setTimeoutTimestamp(-10000) } - testIllegalTimeout { state.setTimeoutTimestamp(10000, "-3 second") } - testIllegalTimeout { state.setTimeoutTimestamp(10000, "-1 month") } - testIllegalTimeout { state.setTimeoutTimestamp(10000, "1 month -1 day") } - testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000)) } - testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-3 second") } - testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-1 month") } - testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day") } + state = GroupStateImpl.createForStreaming( + Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) + testIllegalTimeout { + state.setTimeoutTimestamp(-10000) + } + testIllegalTimeout { + state.setTimeoutTimestamp(10000, "-3 second") + } + testIllegalTimeout { + state.setTimeoutTimestamp(10000, "-1 month") + } + testIllegalTimeout { + state.setTimeoutTimestamp(10000, "1 month -1 day") + } + testIllegalTimeout { + state.setTimeoutTimestamp(new Date(-10000)) + } + testIllegalTimeout { + state.setTimeoutTimestamp(new Date(-10000), "-3 second") + } + testIllegalTimeout { + state.setTimeoutTimestamp(new Date(-10000), "-1 month") + } + testIllegalTimeout { + state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day") + } } test("GroupState - hasTimedOut") { for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, EventTimeTimeout)) { + // for streaming queries for (initState <- Seq(None, Some(5))) { - val state1 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false) + val state1 = GroupStateImpl.createForStreaming( + initState, 1000, 1000, timeoutConf, hasTimedOut = false) assert(state1.hasTimedOut === false) - val state2 = new GroupStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true) + + val state2 = GroupStateImpl.createForStreaming( + initState, 1000, 1000, timeoutConf, hasTimedOut = true) assert(state2.hasTimedOut === true) } + + // for batch queries + assert(GroupStateImpl.createForBatch(timeoutConf).hasTimedOut === false) } } test("GroupState - primitive type") { - var intState = new GroupStateImpl[Int](None) + var intState = GroupStateImpl.createForStreaming[Int]( + None, 1000, 1000, NoTimeout, hasTimedOut = false) intercept[NoSuchElementException] { intState.get } assert(intState.getOption === None) - intState = new GroupStateImpl[Int](Some(10)) + intState = GroupStateImpl.createForStreaming[Int]( + Some(10), 1000, 1000, NoTimeout, hasTimedOut = false) assert(intState.get == 10) intState.update(0) assert(intState.get == 0) @@ -210,7 +295,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val beforeTimeoutThreshold = 999 val afterTimeoutThreshold = 1001 - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" @@ -318,6 +402,44 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } + // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(), + // Try to remove these cases in the future + for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { + val testName = + if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout" + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed", + stateUpdates = state => { state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedException = classOf[IllegalStateException]) + + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed", + stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = Some(5), + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedException = classOf[IllegalStateException]) + + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout without init state not allowed", + stateUpdates = state => { state.setTimeoutTimestamp(10000) }, + timeoutConf = EventTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedException = classOf[IllegalStateException]) + + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", + stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) }, + timeoutConf = EventTimeTimeout, + priorState = Some(5), + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedException = classOf[IllegalStateException]) + } + // Tests for StateStoreUpdater.updateStateForTimedOutKeys() val preTimeoutState = Some(5) for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { @@ -558,7 +680,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf .flatMapGroupsWithState(Update, ProcessingTimeTimeout)(stateFunc) testStream(result, Update)( - StartStream(ProcessingTime("1 second"), triggerClock = clock), + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, "a"), AdvanceManualClock(1 * 1000), CheckLastBatch(("a", "1")), @@ -589,7 +711,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf ) } - test("flatMapGroupsWithState - streaming with event time timeout") { + test("flatMapGroupsWithState - streaming with event time timeout + watermark") { // Function to maintain the max event time // Returns the max event time in the state, or -1 if the state was removed by timeout val stateFunc = ( @@ -623,7 +745,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) testStream(result, Update)( - StartStream(ProcessingTime("1 second")), + StartStream(Trigger.ProcessingTime("1 second")), AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), // Set timeout timestamp of ... CheckLastBatch(("a", 15)), // "a" to 15 + 5 = 20s, watermark to 5s AddData(inputData, ("a", 4)), // Add data older than watermark for "a" @@ -677,15 +799,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } test("mapGroupsWithState - batch") { - val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + // Test the following + // - no initial state + // - timeouts operations work, does not throw any error [SPARK-20792] + // - works with primitive state type + val stateFunc = (key: String, values: Iterator[String], state: GroupState[Int]) => { if (state.exists) throw new IllegalArgumentException("state.exists should be false") + state.setTimeoutTimestamp(0, "1 hour") + state.update(10) (key, values.size) } checkAnswer( spark.createDataset(Seq("a", "a", "b")) .groupByKey(x => x) - .mapGroupsWithState(stateFunc) + .mapGroupsWithState(EventTimeTimeout)(stateFunc) .toDF, spark.createDataset(Seq(("a", 2), ("b", 1))).toDF) } @@ -761,6 +889,44 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf assert(e.getMessage === "The output mode of function should be append or update") } + def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = { + test("SPARK-20714: watermark does not fail query when timeout = " + timeoutConf) { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + val stateFunc = + (key: String, values: Iterator[(String, Long)], state: GroupState[RunningCount]) => { + if (state.hasTimedOut) { + state.remove() + Iterator((key, "-1")) + } else { + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + state.setTimeoutDuration("10 seconds") + Iterator((key, count.toString)) + } + } + + val clock = new StreamManualClock + val inputData = MemoryStream[(String, Long)] + val result = + inputData.toDF().toDF("key", "time") + .selectExpr("key", "cast(time as timestamp) as timestamp") + .withWatermark("timestamp", "10 second") + .as[(String, Long)] + .groupByKey(x => x._1) + .flatMapGroupsWithState(Update, ProcessingTimeTimeout)(stateFunc) + + testStream(result, Update)( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), + AddData(inputData, ("a", 1L)), + AdvanceManualClock(1 * 1000), + CheckLastBatch(("a", "1")) + ) + } + } + testWithTimeout(NoTimeout) + testWithTimeout(ProcessingTimeTimeout) + def testStateUpdateWithData( testName: String, stateUpdates: GroupState[Int] => Unit, @@ -768,7 +934,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf priorState: Option[Int], priorTimeoutTimestamp: Long = NO_TIMESTAMP, expectedState: Option[Int] = None, - expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { + expectedTimeoutTimestamp: Long = NO_TIMESTAMP, + expectedException: Class[_ <: Exception] = null): Unit = { if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) { return // there can be no prior timestamp, when there is no prior state @@ -782,7 +949,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } testStateUpdate( testTimeoutUpdates = false, mapGroupsFunc, timeoutConf, - priorState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp) + priorState, priorTimeoutTimestamp, + expectedState, expectedTimeoutTimestamp, expectedException) } } @@ -801,9 +969,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf stateUpdates(state) Iterator.empty } + testStateUpdate( testTimeoutUpdates = true, mapGroupsFunc, timeoutConf = timeoutConf, - preTimeoutState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp) + preTimeoutState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp, null) } } @@ -814,7 +983,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf priorState: Option[Int], priorTimeoutTimestamp: Long, expectedState: Option[Int], - expectedTimeoutTimestamp: Long): Unit = { + expectedTimeoutTimestamp: Long, + expectedException: Class[_ <: Exception]): Unit = { val store = newStateStore() val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( @@ -829,22 +999,30 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } // Call updating function to update state store - val returnedIter = if (testTimeoutUpdates) { - updater.updateStateForTimedOutKeys() - } else { - updater.updateStateForKeysWithData(Iterator(key)) + def callFunction() = { + val returnedIter = if (testTimeoutUpdates) { + updater.updateStateForTimedOutKeys() + } else { + updater.updateStateForKeysWithData(Iterator(key)) + } + returnedIter.size // consume the iterator to force state updates } - returnedIter.size // consumer the iterator to force state updates - - // Verify updated state in store - val updatedStateRow = store.get(key) - assert( - updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState, - "final state not as expected") - if (updatedStateRow.nonEmpty) { + if (expectedException != null) { + // Call function and verify the exception type + val e = intercept[Exception] { callFunction() } + assert(e.getClass === expectedException, "Exception thrown but of the wrong type") + } else { + // Call function to update and verify updated state in store + callFunction() + val updatedStateRow = store.get(key) assert( - updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp, - "final timeout timestamp not as expected") + updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState, + "final state not as expected") + if (updatedStateRow.nonEmpty) { + assert( + updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp, + "final timeout timestamp not as expected") + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 13fe51a557733..1fc062974e185 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -25,6 +25,8 @@ import scala.util.control.ControlThrowable import org.apache.commons.io.FileUtils +import org.apache.spark.SparkContext +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand @@ -69,6 +71,27 @@ class StreamSuite extends StreamTest { CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four"))) } + test("SPARK-20432: union one stream with itself") { + val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load().select("a") + val unioned = df.union(df) + withTempDir { outputDir => + withTempDir { checkpointDir => + val query = + unioned + .writeStream.format("parquet") + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .start(outputDir.getAbsolutePath) + try { + query.processAllAvailable() + val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long] + checkDatasetUnorderly[Long](outputDf, (0L to 10L).union((0L to 10L)).toArray: _*) + } finally { + query.stop() + } + } + } + } + test("union two streams") { val inputData1 = MemoryStream[Int] val inputData2 = MemoryStream[Int] @@ -120,6 +143,33 @@ class StreamSuite extends StreamTest { assertDF(df) } + test("Within the same streaming query, one StreamingRelation should only be transformed to one " + + "StreamingExecutionRelation") { + val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load() + var query: StreamExecution = null + try { + query = + df.union(df) + .writeStream + .format("memory") + .queryName("memory") + .start() + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + query.awaitInitialization(streamingTimeout.toMillis) + val executionRelations = + query + .logicalPlan + .collect { case ser: StreamingExecutionRelation => ser } + assert(executionRelations.size === 2) + assert(executionRelations.distinct.size === 1) + } finally { + if (query != null) { + query.stop() + } + } + } + test("unsupported queries") { val streamInput = MemoryStream[Int] val batchInput = Seq(1, 2, 3).toDS() @@ -500,6 +550,70 @@ class StreamSuite extends StreamTest { } } } + + test("calling stop() on a query cancels related jobs") { + val input = MemoryStream[Int] + val query = input + .toDS() + .map { i => + while (!org.apache.spark.TaskContext.get().isInterrupted()) { + // keep looping till interrupted by query.stop() + Thread.sleep(100) + } + i + } + .writeStream + .format("console") + .start() + + input.addData(1) + // wait for jobs to start + eventually(timeout(streamingTimeout)) { + assert(sparkContext.statusTracker.getActiveJobIds().nonEmpty) + } + + query.stop() + // make sure jobs are stopped + eventually(timeout(streamingTimeout)) { + assert(sparkContext.statusTracker.getActiveJobIds().isEmpty) + } + } + + test("batch id is updated correctly in the job description") { + val queryName = "memStream" + @volatile var jobDescription: String = null + def assertDescContainsQueryNameAnd(batch: Integer): Unit = { + // wait for listener event to be processed + spark.sparkContext.listenerBus.waitUntilEmpty(streamingTimeout.toMillis) + assert(jobDescription.contains(queryName) && jobDescription.contains(s"batch = $batch")) + } + + spark.sparkContext.addSparkListener(new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobDescription = jobStart.properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION) + } + }) + + val input = MemoryStream[Int] + val query = input + .toDS() + .map(_ + 1) + .writeStream + .format("memory") + .queryName(queryName) + .start() + + input.addData(1) + query.processAllAvailable() + assertDescContainsQueryNameAnd(batch = 0) + input.addData(2, 3) + query.processAllAvailable() + assertDescContainsQueryNameAnd(batch = 1) + input.addData(4) + query.processAllAvailable() + assertDescContainsQueryNameAnd(batch = 2) + query.stop() + } } abstract class FakeSource extends StreamSourceProvider { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 5bc36dd30f6d1..2a4039cc5831a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -172,8 +172,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { * * @param isFatalError if this is a fatal error. If so, the error should also be caught by * UncaughtExceptionHandler. + * @param assertFailure a function to verify the error. */ case class ExpectFailure[T <: Throwable : ClassTag]( + assertFailure: Throwable => Unit = _ => {}, isFatalError: Boolean = false) extends StreamAction { val causeClass: Class[T] = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] override def toString(): String = @@ -455,6 +457,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause") streamThreadDeathCause = null } + ef.assertFailure(exception.getCause) } catch { case _: InterruptedException => case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index f796a4cb4a398..b6e82b621c8cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -69,6 +69,22 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte ) } + test("count distinct") { + val inputData = MemoryStream[(Int, Seq[Int])] + + val aggregated = + inputData.toDF() + .select($"*", explode($"_2") as 'value) + .groupBy($"_1") + .agg(size(collect_set($"value"))) + .as[(Int, Int)] + + testStream(aggregated, Update)( + AddData(inputData, (1, Seq(1, 2))), + CheckLastBatch((1, 2)) + ) + } + test("simple count, complete mode") { val inputData = MemoryStream[Int] @@ -251,7 +267,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte .where('value >= current_timestamp().cast("long") - 10L) testStream(aggregated, Complete)( - StartStream(ProcessingTime("10 seconds"), triggerClock = clock), + StartStream(Trigger.ProcessingTime("10 seconds"), triggerClock = clock), // advance clock to 10 seconds, all keys retained AddData(inputData, 0L, 5L, 5L, 10L), @@ -278,7 +294,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte clock.advance(60 * 1000L) true }, - StartStream(ProcessingTime("10 seconds"), triggerClock = clock), + StartStream(Trigger.ProcessingTime("10 seconds"), triggerClock = clock), // The commit log blown, causing the last batch to re-run CheckLastBatch((20L, 1), (85L, 1)), AssertOnQuery { q => @@ -306,7 +322,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte .where($"value".cast("date") >= date_sub(current_date(), 10)) .select(($"value".cast("long") / DateTimeUtils.SECONDS_PER_DAY).cast("long"), $"count(1)") testStream(aggregated, Complete)( - StartStream(ProcessingTime("10 day"), triggerClock = clock), + StartStream(Trigger.ProcessingTime("10 day"), triggerClock = clock), // advance clock to 10 days, should retain all keys AddData(inputData, 0L, 5L, 5L, 10L), AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10), @@ -330,7 +346,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte clock.advance(DateTimeUtils.MILLIS_PER_DAY * 60) true }, - StartStream(ProcessingTime("10 day"), triggerClock = clock), + StartStream(Trigger.ProcessingTime("10 day"), triggerClock = clock), // Commit log blown, causing a re-run of the last batch CheckLastBatch((20L, 1), (85L, 1)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index b8a694c177310..59c6a6fade175 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -21,6 +21,7 @@ import java.util.UUID import scala.collection.mutable import scala.concurrent.duration._ +import scala.language.reflectiveCalls import org.scalactic.TolerantNumerics import org.scalatest.concurrent.AsyncAssertions.Waiter diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala index b49efa6890236..2986b7f1eecfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala @@ -78,9 +78,9 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter { eventually(Timeout(streamingTimeout)) { require(!q2.isActive) require(q2.exception.isDefined) + assert(spark.streams.get(q2.id) === null) + assert(spark.streams.active.toSet === Set(q3)) } - assert(spark.streams.get(q2.id) === null) - assert(spark.streams.active.toSet === Set(q3)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index b69536ed37463..ee5af65cd71c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -613,6 +613,18 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + test("processAllAvailable should not block forever when a query is stopped") { + val input = MemoryStream[Int] + input.addData(1) + val query = input.toDF().writeStream + .trigger(Trigger.Once()) + .format("console") + .start() + failAfter(streamingTimeout) { + query.processAllAvailable() + } + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index dc2506a48ad00..bae9d811f7790 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -641,6 +641,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { test("temp checkpoint dir should be deleted if a query is stopped without errors") { import testImplicits._ val query = MemoryStream[Int].toDS.writeStream.format("console").start() + query.processAllAvailable() val checkpointDir = new Path( query.asInstanceOf[StreamingQueryWrapper].streamingQuery.checkpointRoot) val fs = checkpointDir.getFileSystem(spark.sessionState.newHadoopConf()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 6a4cc95d36bea..f6d47734d7e83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -20,13 +20,15 @@ package org.apache.spark.sql.test import java.io.File import java.net.URI import java.nio.file.Files -import java.util.UUID +import java.util.{Locale, UUID} +import scala.concurrent.duration._ import scala.language.implicitConversions import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ @@ -49,7 +51,7 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. */ private[sql] trait SQLTestUtils - extends SparkFunSuite + extends SparkFunSuite with Eventually with BeforeAndAfterAll with SQLTestData { self => @@ -138,6 +140,15 @@ private[sql] trait SQLTestUtils } } + /** + * Waits for all tasks on all executors to be finished. + */ + protected def waitForTasksToFinish(): Unit = { + eventually(timeout(10.seconds)) { + assert(spark.sparkContext.statusTracker + .getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } /** * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` * returns. @@ -146,7 +157,11 @@ private[sql] trait SQLTestUtils */ protected def withTempDir(f: File => Unit): Unit = { val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally Utils.deleteRecursively(dir) + try f(dir) finally { + // wait for all tasks to finish before deleting files + waitForTasksToFinish() + Utils.deleteRecursively(dir) + } } /** @@ -222,12 +237,39 @@ private[sql] trait SQLTestUtils try f(dbName) finally { if (spark.catalog.currentDatabase == dbName) { - spark.sql(s"USE ${DEFAULT_DATABASE}") + spark.sql(s"USE $DEFAULT_DATABASE") } spark.sql(s"DROP DATABASE $dbName CASCADE") } } + /** + * Drops database `dbName` after calling `f`. + */ + protected def withDatabase(dbNames: String*)(f: => Unit): Unit = { + try f finally { + dbNames.foreach { name => + spark.sql(s"DROP DATABASE IF EXISTS $name") + } + spark.sql(s"USE $DEFAULT_DATABASE") + } + } + + /** + * Enables Locale `language` before executing `f`, then switches back to the default locale of JVM + * after `f` returns. + */ + protected def withLocale(language: String)(f: => Unit): Unit = { + val originalLocale = Locale.getDefault + try { + // Add Locale setting + Locale.setDefault(new Locale(language)) + f + } finally { + Locale.setDefault(originalLocale) + } + } + /** * Activates database `db` before executing `f`, then switches back to `default` database after * `f` returns. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index e122b39f6fc40..7cea4c02155ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,19 +17,22 @@ package org.apache.spark.sql.test +import scala.concurrent.duration._ + import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.Eventually import org.apache.spark.{DebugFilesystem, SparkConf} import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.internal.SQLConf - /** * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. */ -trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { +trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { - protected val sparkConf = new SparkConf() + protected def sparkConf = { + new SparkConf().set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + } /** * The [[TestSparkSession]] to use for all tests in this suite. @@ -50,8 +53,7 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { protected implicit def sqlContext: SQLContext = _spark.sqlContext protected def createSparkSession: TestSparkSession = { - new TestSparkSession( - sparkConf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + new TestSparkSession(sparkConf) } /** @@ -72,6 +74,7 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { protected override def afterAll(): Unit = { super.afterAll() if (_spark != null) { + _spark.sessionState.catalog.reset() _spark.stop() _spark = null } @@ -84,6 +87,10 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { protected override def afterEach(): Unit = { super.afterEach() - DebugFilesystem.assertNoOpenStreams() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } } } diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 9c879218ddc0d..8922c2b0a4670 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.0-csd-1-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index ff3784cab9e26..1d1074a2a7387 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -253,6 +253,8 @@ private[hive] class SparkExecuteStatementOperation( return } else { setState(OperationState.ERROR) + HiveThriftServer2.listener.onStatementError( + statementId, e.getMessage, SparkUtils.exceptionString(e)) throw e } // Actually do need to catch Throwable as some failures don't inherit from Exception and diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index f39e9dcd3a5bb..38b8605745752 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -39,7 +39,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { - val parameterId = request.getParameter("id") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterId = UIUtils.stripXSS(request.getParameter("id")) require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") val content = @@ -197,4 +198,3 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true) } } - diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 0f249d7d59351..233e406c1d1c0 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.0-csd-1-SNAPSHOT ../../pom.xml @@ -162,6 +162,10 @@ org.apache.thrift libfb303 + + org.apache.derby + derby + org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 8b0fdf49cefab..4ac825ffd1105 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -72,6 +72,11 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat classOf[TException].getCanonicalName, classOf[InvocationTargetException].getCanonicalName) + @transient + protected[sql] var hadoopFileSelector: Option[HadoopFileSelector] = None + + private[this] var tableNamePreprocessor: (String) => String = identity + /** * Whether this is an exception thrown by the hive client that should be wrapped. * @@ -114,7 +119,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * should interpret these special data source properties and restore the original table metadata * before returning it. */ - private def getRawTable(db: String, table: String): CatalogTable = withClient { + private[hive] def getRawTable(db: String, table: String): CatalogTable = withClient { client.getTable(db, table) } @@ -137,17 +142,33 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } + /** + * Checks the validity of column names. Hive metastore disallows the table to use comma in + * data column names. Partition columns do not have such a restriction. Views do not have such + * a restriction. + */ + private def verifyColumnNames(table: CatalogTable): Unit = { + if (table.tableType != VIEW) { + table.dataSchema.map(_.name).foreach { colName => + if (colName.contains(",")) { + throw new AnalysisException("Cannot create a table having a column whose name contains " + + s"commas in Hive metastore. Table: ${table.identifier}; Column: $colName") + } + } + } + } + // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- - override def createDatabase( + override protected def doCreateDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = withClient { client.createDatabase(dbDefinition, ignoreIfExists) } - override def dropDatabase( + override protected def doDropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = withClient { @@ -194,7 +215,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Tables // -------------------------------------------------------------------------- - override def createTable( + override protected def doCreateTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) @@ -202,44 +223,42 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val table = tableDefinition.identifier.table requireDbExists(db) verifyTableProperties(tableDefinition) + verifyColumnNames(tableDefinition) if (tableExists(db, table) && !ignoreIfExists) { throw new TableAlreadyExistsException(db = db, table = table) } - if (tableDefinition.tableType == VIEW) { - client.createTable(tableDefinition, ignoreIfExists) + // Ideally we should not create a managed table with location, but Hive serde table can + // specify location for managed table. And in [[CreateDataSourceTableAsSelectCommand]] we have + // to create the table directory and write out data before we create this table, to avoid + // exposing a partial written table. + val needDefaultTableLocation = tableDefinition.tableType == MANAGED && + tableDefinition.storage.locationUri.isEmpty + + val tableLocation = if (needDefaultTableLocation) { + Some(CatalogUtils.stringToURI(defaultTablePath(tableDefinition.identifier))) } else { - // Ideally we should not create a managed table with location, but Hive serde table can - // specify location for managed table. And in [[CreateDataSourceTableAsSelectCommand]] we have - // to create the table directory and write out data before we create this table, to avoid - // exposing a partial written table. - val needDefaultTableLocation = tableDefinition.tableType == MANAGED && - tableDefinition.storage.locationUri.isEmpty - - val tableLocation = if (needDefaultTableLocation) { - Some(CatalogUtils.stringToURI(defaultTablePath(tableDefinition.identifier))) - } else { - tableDefinition.storage.locationUri - } + tableDefinition.storage.locationUri + } - if (DDLUtils.isHiveTable(tableDefinition)) { - val tableWithDataSourceProps = tableDefinition.copy( - // We can't leave `locationUri` empty and count on Hive metastore to set a default table - // location, because Hive metastore uses hive.metastore.warehouse.dir to generate default - // table location for tables in default database, while we expect to use the location of - // default database. - storage = tableDefinition.storage.copy(locationUri = tableLocation), - // Here we follow data source tables and put table metadata like table schema, partition - // columns etc. in table properties, so that we can work around the Hive metastore issue - // about not case preserving and make Hive serde table support mixed-case column names. - properties = tableDefinition.properties ++ tableMetaToTableProps(tableDefinition)) - client.createTable(tableWithDataSourceProps, ignoreIfExists) - } else { - createDataSourceTable( - tableDefinition.withNewStorage(locationUri = tableLocation), - ignoreIfExists) - } + if (DDLUtils.isDatasourceTable(tableDefinition)) { + createDataSourceTable( + tableDefinition.withNewStorage(locationUri = tableLocation), + ignoreIfExists) + } else { + val tableWithDataSourceProps = tableDefinition.copy( + // We can't leave `locationUri` empty and count on Hive metastore to set a default table + // location, because Hive metastore uses hive.metastore.warehouse.dir to generate default + // table location for tables in default database, while we expect to use the location of + // default database. + storage = tableDefinition.storage.copy(locationUri = tableLocation), + // Here we follow data source tables and put table metadata like table schema, partition + // columns etc. in table properties, so that we can work around the Hive metastore issue + // about not case preserving and make Hive serde table and view support mixed-case column + // names. + properties = tableDefinition.properties ++ tableMetaToTableProps(tableDefinition)) + client.createTable(tableWithDataSourceProps, ignoreIfExists) } } @@ -372,6 +391,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat * can be used as table properties later. */ private def tableMetaToTableProps(table: CatalogTable): mutable.Map[String, String] = { + tableMetaToTableProps(table, table.schema) + } + + private def tableMetaToTableProps( + table: CatalogTable, + schema: StructType): mutable.Map[String, String] = { val partitionColumns = table.partitionColumnNames val bucketSpec = table.bucketSpec @@ -380,7 +405,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // property. In this case, we split the JSON string and store each part as a separate table // property. val threshold = conf.get(SCHEMA_STRING_LENGTH_THRESHOLD) - val schemaJsonString = table.schema.json + val schemaJsonString = schema.json // Split the JSON string. val parts = schemaJsonString.grouped(threshold).toSeq properties.put(DATASOURCE_SCHEMA_NUMPARTS, parts.size.toString) @@ -456,7 +481,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override def dropTable( + override protected def doDropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -465,7 +490,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.dropTable(db, table, ignoreIfNotExists, purge) } - override def renameTable(db: String, oldName: String, newName: String): Unit = withClient { + override protected def doRenameTable( + db: String, + oldName: String, + newName: String): Unit = withClient { val rawTable = getRawTable(db, oldName) // Note that Hive serde tables don't use path option in storage properties to store the value @@ -610,19 +638,29 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat override def alterTableSchema(db: String, table: String, schema: StructType): Unit = withClient { requireTableExists(db, table) val rawTable = getRawTable(db, table) - val withNewSchema = rawTable.copy(schema = schema) // Add table metadata such as table schema, partition columns, etc. to table properties. - val updatedTable = withNewSchema.copy( - properties = withNewSchema.properties ++ tableMetaToTableProps(withNewSchema)) - try { - client.alterTable(updatedTable) - } catch { - case NonFatal(e) => - val warningMessage = - s"Could not alter schema of table ${rawTable.identifier.quotedString} in a Hive " + - "compatible way. Updating Hive metastore in Spark SQL specific format." - logWarning(warningMessage, e) - client.alterTable(updatedTable.copy(schema = updatedTable.partitionSchema)) + val updatedProperties = rawTable.properties ++ tableMetaToTableProps(rawTable, schema) + val withNewSchema = rawTable.copy(properties = updatedProperties, schema = schema) + verifyColumnNames(withNewSchema) + + if (isDatasourceTable(rawTable)) { + // For data source tables, first try to write it with the schema set; if that does not work, + // try again with updated properties and the partition schema. This is a simplified version of + // what createDataSourceTable() does, and may leave the table in a state unreadable by Hive + // (for example, the schema does not match the data source schema, or does not match the + // storage descriptor). + try { + client.alterTable(withNewSchema) + } catch { + case NonFatal(e) => + val warningMessage = + s"Could not alter schema of table ${rawTable.identifier.quotedString} in a Hive " + + "compatible way. Updating Hive metastore in Spark SQL specific format." + logWarning(warningMessage, e) + client.alterTable(withNewSchema.copy(schema = rawTable.partitionSchema)) + } + } else { + client.alterTable(withNewSchema) } } @@ -630,10 +668,6 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat restoreTableMetadata(getRawTable(db, table)) } - override def getTableOption(db: String, table: String): Option[CatalogTable] = withClient { - client.getTableOption(db, table).map(restoreTableMetadata) - } - /** * Restores table metadata from the table properties. This method is kind of a opposite version * of [[createTable]]. @@ -648,16 +682,21 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat var table = inputTable - if (table.tableType != VIEW) { - table.properties.get(DATASOURCE_PROVIDER) match { - // No provider in table properties, which means this is a Hive serde table. - case None => - table = restoreHiveSerdeTable(table) - - // This is a regular data source table. - case Some(provider) => - table = restoreDataSourceTable(table, provider) - } + table.properties.get(DATASOURCE_PROVIDER) match { + case None if table.tableType == VIEW => + // If this is a view created by Spark 2.2 or higher versions, we should restore its schema + // from table properties. + if (table.properties.contains(DATASOURCE_SCHEMA_NUMPARTS)) { + table = table.copy(schema = getSchemaFromTableProperties(table)) + } + + // No provider in table properties, which means this is a Hive serde table. + case None => + table = restoreHiveSerdeTable(table) + + // This is a regular data source table. + case Some(provider) => + table = restoreDataSourceTable(table, provider) } // construct Spark's statistics from information in Hive metastore @@ -696,6 +735,20 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat properties = table.properties.filterNot { case (key, _) => key.startsWith(SPARK_SQL_PREFIX) }) } + // Reorder table schema to put partition columns at the end. Before Spark 2.2, the partition + // columns are not put at the end of schema. We need to reorder it when reading the schema + // from the table properties. + private def reorderSchema(schema: StructType, partColumnNames: Seq[String]): StructType = { + val partitionFields = partColumnNames.map { partCol => + schema.find(_.name == partCol).getOrElse { + throw new AnalysisException("The metadata is corrupted. Unable to find the " + + s"partition column names from the schema. schema: ${schema.catalogString}. " + + s"Partition columns: ${partColumnNames.mkString("[", ", ", "]")}") + } + } + StructType(schema.filterNot(partitionFields.contains) ++ partitionFields) + } + private def restoreHiveSerdeTable(table: CatalogTable): CatalogTable = { val hiveTable = table.copy( provider = Some(DDLUtils.HIVE_PROVIDER), @@ -705,10 +758,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // schema from table properties. if (table.properties.contains(DATASOURCE_SCHEMA_NUMPARTS)) { val schemaFromTableProps = getSchemaFromTableProperties(table) - if (DataType.equalsIgnoreCaseAndNullability(schemaFromTableProps, table.schema)) { + val partColumnNames = getPartitionColumnsFromTableProperties(table) + val reorderedSchema = reorderSchema(schema = schemaFromTableProps, partColumnNames) + + if (DataType.equalsIgnoreCaseAndNullability(reorderedSchema, table.schema)) { hiveTable.copy( - schema = schemaFromTableProps, - partitionColumnNames = getPartitionColumnsFromTableProperties(table), + schema = reorderedSchema, + partitionColumnNames = partColumnNames, bucketSpec = getBucketSpecFromTableProperties(table)) } else { // Hive metastore may change the table schema, e.g. schema inference. If the table @@ -738,11 +794,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } val partitionProvider = table.properties.get(TABLE_PARTITION_PROVIDER) + val schemaFromTableProps = getSchemaFromTableProperties(table) + val partColumnNames = getPartitionColumnsFromTableProperties(table) + val reorderedSchema = reorderSchema(schema = schemaFromTableProps, partColumnNames) + table.copy( provider = Some(provider), storage = storageWithLocation, - schema = getSchemaFromTableProperties(table), - partitionColumnNames = getPartitionColumnsFromTableProperties(table), + schema = reorderedSchema, + partitionColumnNames = partColumnNames, bucketSpec = getBucketSpecFromTableProperties(table), tracksPartitionsInCatalog = partitionProvider == Some(TABLE_PARTITION_PROVIDER_CATALOG)) } @@ -1030,9 +1090,19 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat table: String, partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = withClient { val partColNameMap = buildLowerCasePartColNameMap(getTable(db, table)) - client.getPartitions(db, table, partialSpec.map(lowerCasePartitionSpec)).map { part => + val res = client.getPartitions(db, table, partialSpec.map(lowerCasePartitionSpec)).map { part => part.copy(spec = restorePartitionSpec(part.spec, partColNameMap)) } + + partialSpec match { + // This might be a bug of Hive: When the partition value inside the partial partition spec + // contains dot, and we ask Hive to list partitions w.r.t. the partial partition spec, Hive + // treats dot as matching any single character and may return more partitions than we + // expected. Here we do an extra filter to drop unexpected partitions. + case Some(spec) if spec.exists(_._2.contains(".")) => + res.filter(p => isPartialPartitionSpec(spec, p.spec)) + case _ => res + } } override def listPartitionsByFilter( @@ -1056,7 +1126,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Functions // -------------------------------------------------------------------------- - override def createFunction( + override protected def doCreateFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) @@ -1069,12 +1139,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override def dropFunction(db: String, name: String): Unit = withClient { + override protected def doDropFunction(db: String, name: String): Unit = withClient { requireFunctionExists(db, name) client.dropFunction(db, name) } - override def renameFunction(db: String, oldName: String, newName: String): Unit = withClient { + override protected def doRenameFunction( + db: String, + oldName: String, + newName: String): Unit = withClient { requireFunctionExists(db, oldName) requireFunctionNotExists(db, newName) client.renameFunction(db, oldName, newName) @@ -1095,6 +1168,43 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.listFunctions(db, pattern) } + /** + * Allows the user to pre-process table names before the Hive metastore is looked up. This can + * be used to encode additional information into the table name, such as a version number + * (e.g. `mytable_v1`, `mytable_v2`, etc.) + * @param newTableNamePreprocessor a function to be applied to Hive table name before we look up + * the table in the Hive metastore. + */ + def setTableNamePreprocessor(newTableNamePreprocessor: (String) => String): Unit = { + tableNamePreprocessor = newTableNamePreprocessor + } + + def getTableNamePreprocessor: (String) => String = tableNamePreprocessor + + /** + * Allows to register a custom way to select files/directories to be included in a table scan + * based on the table name. This can be used together with [[setTableNamePreprocessor]] to + * customize table scan results based on the specified table name. E.g. `mytable_v1` could have a + * different set of files than `mytable_v2`, and both of these "virtual tables" would be backed + * by a real Hive table `mytable`. Note that the table name passed to the user-provided file + * selection method is the name specified in the query, not the table name in the Hive metastore + * that is generated by applying the user-specified "table name preprocessor" method. + * @param hadoopFileSelector the user Hadoop file selection strategy + * @see [[setTableNamePreprocessor]] + */ + def setHadoopFileSelector(hadoopFileSelector: HadoopFileSelector): Unit = { + this.hadoopFileSelector = Some(hadoopFileSelector) + } + + /** + * Removes the "Hadoop file selector" strategy that was installed using the + * [[setHadoopFileSelector]] method. + */ + def unsetHadoopFileSelector(): Unit = { + hadoopFileSelector = None + } + + override def findHadoopFileSelector: Option[HadoopFileSelector] = hadoopFileSelector } object HiveExternalCatalog { @@ -1193,4 +1303,14 @@ object HiveExternalCatalog { getColumnNamesByType(metadata.properties, "sort", "sorting columns")) } } + + /** + * Detects a data source table. This checks both the table provider and the table properties, + * unlike DDLUtils which just checks the former. + */ + private[spark] def isDatasourceTable(table: CatalogTable): Boolean = { + val provider = table.provider.orElse(table.properties.get(DATASOURCE_PROVIDER)) + provider.isDefined && provider != Some(DDLUtils.HIVE_PROVIDER) + } + } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 6b98066cb76c8..a87aa4877a128 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.types._ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging { // these are def_s and not val/lazy val since the latter would introduce circular references private def sessionState = sparkSession.sessionState - private def tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache + private def catalogProxy = sparkSession.sessionState.catalog import HiveMetastoreCatalog._ /** These locks guard against multiple attempts to instantiate a table, which wastes memory. */ @@ -61,7 +61,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log val key = QualifiedTableName( table.database.getOrElse(sessionState.catalog.getCurrentDatabase).toLowerCase, table.table.toLowerCase) - tableRelationCache.getIfPresent(key) + catalogProxy.getCachedTable(key) } private def getCached( @@ -71,7 +71,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log expectedFileFormat: Class[_ <: FileFormat], partitionSchema: Option[StructType]): Option[LogicalRelation] = { - tableRelationCache.getIfPresent(tableIdentifier) match { + catalogProxy.getCachedTable(tableIdentifier) match { case null => None // Cache miss case logical @ LogicalRelation(relation: HadoopFsRelation, _, _) => val cachedRelationFileFormatClass = relation.fileFormat.getClass @@ -92,27 +92,27 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log Some(logical) } else { // If the cached relation is not updated, we invalidate it right away. - tableRelationCache.invalidate(tableIdentifier) + catalogProxy.invalidateCachedTable(tableIdentifier) None } case _ => logWarning(s"Table $tableIdentifier should be stored as $expectedFileFormat. " + s"However, we are getting a ${relation.fileFormat} from the metastore cache. " + "This cached entry will be invalidated.") - tableRelationCache.invalidate(tableIdentifier) + catalogProxy.invalidateCachedTable(tableIdentifier) None } case other => logWarning(s"Table $tableIdentifier should be stored as $expectedFileFormat. " + s"However, we are getting a $other from the metastore cache. " + "This cached entry will be invalidated.") - tableRelationCache.invalidate(tableIdentifier) + catalogProxy.invalidateCachedTable(tableIdentifier) None } } def convertToLogicalRelation( - relation: CatalogRelation, + relation: HiveTableRelation, options: Map[String, String], fileFormatClass: Class[_ <: FileFormat], fileType: String): LogicalRelation = { @@ -176,7 +176,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log fileFormat = fileFormat, options = options)(sparkSession = sparkSession) val created = LogicalRelation(fsRelation, updatedTable) - tableRelationCache.put(tableIdentifier, created) + catalogProxy.cacheTable(tableIdentifier, created) created } @@ -184,10 +184,16 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log }) } else { val rootPath = tablePath + val paths: Seq[Path] = + if (fileType != "parquet") { + Seq(rootPath) + } else { + selectParquetLocationDirectories(relation.tableMeta.identifier.table, Option(rootPath)) + } withTableCreationLock(tableIdentifier, { val cached = getCached( tableIdentifier, - Seq(rootPath), + paths, metastoreSchema, fileFormatClass, None) @@ -197,7 +203,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log LogicalRelation( DataSource( sparkSession = sparkSession, - paths = rootPath.toString :: Nil, + paths = paths.map(_.toString), userSpecifiedSchema = Option(dataSchema), // We don't support hive bucketed tables, only ones we write out. bucketSpec = None, @@ -205,14 +211,14 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log className = fileType).resolveRelation(), table = updatedTable) - tableRelationCache.put(tableIdentifier, created) + catalogProxy.cacheTable(tableIdentifier, created) created } logicalRelation }) } - // The inferred schema may have different filed names as the table schema, we should respect + // The inferred schema may have different field names as the table schema, we should respect // it, but also respect the exprId in table relation output. assert(result.output.length == relation.output.length && result.output.zip(relation.output).forall { case (a1, a2) => a1.dataType == a2.dataType }) @@ -222,8 +228,28 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log result.copy(output = newOutput) } + private[hive] def selectParquetLocationDirectories( + tableName: String, + locationOpt: Option[Path]): Seq[Path] = { + val hadoopConf = sparkSession.sparkContext.hadoopConfiguration + val paths: Option[Seq[Path]] = for { + selector <- sparkSession.sharedState.externalCatalog.findHadoopFileSelector + location <- locationOpt + fs = location.getFileSystem(hadoopConf) + selectedPaths <- selector.selectFiles(tableName, fs, location) + selectedDir = for { + selectedPath <- selectedPaths + if selectedPath + .getFileSystem(hadoopConf) + .isDirectory(selectedPath) + } yield selectedPath + if selectedDir.nonEmpty + } yield selectedDir + paths.getOrElse(Seq(locationOpt.orNull)) + } + private def inferIfNeeded( - relation: CatalogRelation, + relation: HiveTableRelation, options: Map[String, String], fileFormat: FileFormat, fileIndexOpt: Option[FileIndex] = None): (StructType, CatalogTable) = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 377d4f2473c58..6227e780c0409 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -140,7 +140,7 @@ private[sql] class HiveSessionCatalog( // Hive is case insensitive. val functionName = funcName.unquotedString.toLowerCase(Locale.ROOT) if (!hiveFunctions.contains(functionName)) { - failFunctionLookup(funcName.unquotedString) + failFunctionLookup(funcName) } // TODO: Remove this fallback path once we implement the list of fallback functions @@ -148,12 +148,12 @@ private[sql] class HiveSessionCatalog( val functionInfo = { try { Option(HiveFunctionRegistry.getFunctionInfo(functionName)).getOrElse( - failFunctionLookup(funcName.unquotedString)) + failFunctionLookup(funcName)) } catch { // If HiveFunctionRegistry.getFunctionInfo throws an exception, // we are failing to load a Hive builtin function, which means that // the given function is not a Hive builtin function. - case NonFatal(e) => failFunctionLookup(funcName.unquotedString) + case NonFatal(e) => failFunctionLookup(funcName) } } val className = functionInfo.getFunctionClass.getName diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 9d3b31f39c0f5..e16c9e46b7723 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -101,7 +101,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ Seq( FileSourceStrategy, - DataSourceStrategy, + DataSourceStrategy(conf), SpecialLimits, InMemoryScans, HiveTableScans, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 09a5eda6e543f..53e500ea78fc4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics, CatalogStorageFormat, CatalogTable} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, ScriptTransformation} @@ -116,7 +116,7 @@ class ResolveHiveSerdeTable(session: SparkSession) extends Rule[LogicalPlan] { class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case relation: CatalogRelation + case relation: HiveTableRelation if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty => val table = relation.tableMeta // TODO: check if this estimate is valid for tables after partition pruning. @@ -160,9 +160,9 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { */ object HiveAnalysis extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case InsertIntoTable(relation: CatalogRelation, partSpec, query, overwrite, ifNotExists) - if DDLUtils.isHiveTable(relation.tableMeta) => - InsertIntoHiveTable(relation.tableMeta, partSpec, query, overwrite, ifNotExists) + case InsertIntoTable(r: HiveTableRelation, partSpec, query, overwrite, ifPartitionNotExists) + if DDLUtils.isHiveTable(r.tableMeta) => + InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists) case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) => CreateTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore) @@ -184,13 +184,13 @@ object HiveAnalysis extends Rule[LogicalPlan] { case class RelationConversions( conf: SQLConf, sessionCatalog: HiveSessionCatalog) extends Rule[LogicalPlan] { - private def isConvertible(relation: CatalogRelation): Boolean = { + private def isConvertible(relation: HiveTableRelation): Boolean = { val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) serde.contains("parquet") && conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) || serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC) } - private def convert(relation: CatalogRelation): LogicalRelation = { + private def convert(relation: HiveTableRelation): LogicalRelation = { val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) if (serde.contains("parquet")) { val options = Map(ParquetOptions.MERGE_SCHEMA -> @@ -207,14 +207,14 @@ case class RelationConversions( override def apply(plan: LogicalPlan): LogicalPlan = { plan transformUp { // Write path - case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists) + case InsertIntoTable(r: HiveTableRelation, partition, query, overwrite, ifPartitionNotExists) // Inserting into partitioned table is not supported in Parquet/Orc data source (yet). - if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && - !r.isPartitioned && isConvertible(r) => - InsertIntoTable(convert(r), partition, query, overwrite, ifNotExists) + if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && + !r.isPartitioned && isConvertible(r) => + InsertIntoTable(convert(r), partition, query, overwrite, ifPartitionNotExists) // Read path - case relation: CatalogRelation + case relation: HiveTableRelation if DDLUtils.isHiveTable(relation.tableMeta) && isConvertible(relation) => convert(relation) } @@ -242,7 +242,7 @@ private[hive] trait HiveStrategies { */ object HiveTableScans extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projectList, predicates, relation: CatalogRelation) => + case PhysicalOperation(projectList, predicates, relation: HiveTableRelation) => // Filter out all predicates that only deal with partition keys, these are given to the // hive table scan operator to be used for partition pruning. val partitionKeyIds = AttributeSet(relation.partitionCols) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 16c1103dd1ea3..69ed2db223469 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -39,8 +39,10 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -65,7 +67,10 @@ class HadoopTableReader( @transient private val tableDesc: TableDesc, @transient private val sparkSession: SparkSession, hadoopConf: Configuration) - extends TableReader with Logging { + extends TableReader with CastSupport with Logging { + + private val emptyStringsAsNulls = + sparkSession.conf.get("spark.sql.emptyStringsAsNulls", "false").toBoolean // Hadoop honors "mapreduce.job.maps" as hint, // but will ignore when mapreduce.jobtracker.address is "local". @@ -86,6 +91,8 @@ class HadoopTableReader( private val _broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + override def conf: SQLConf = sparkSession.sessionState.conf + override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = makeRDDForTable( hiveTable, @@ -115,21 +122,32 @@ class HadoopTableReader( val broadcastedHadoopConf = _broadcastedHadoopConf val tablePath = hiveTable.getPath - val inputPathStr = applyFilterIfNeeded(tablePath, filterOpt) + val fs = tablePath.getFileSystem(hadoopConf) + val externalCatalog = sparkSession.sharedState.externalCatalog + val inputPaths: Seq[String] = { externalCatalog match { + case hiveExternalCatalog: HiveExternalCatalog => + hiveExternalCatalog.hadoopFileSelector.flatMap( + _.selectFiles(hiveTable.getTableName, fs, tablePath) + ).map(_.map(_.toString)) + case _ => None + } + }.getOrElse(applyFilterIfNeeded(tablePath, filterOpt)) // logDebug("Table input: %s".format(tablePath)) val ifc = hiveTable.getInputFormatClass .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] - val hadoopRDD = createHadoopRdd(localTableDesc, inputPathStr, ifc) + val hadoopRDD = createHadoopRdd(localTableDesc, inputPaths, ifc) val attrsWithIndex = attributes.zipWithIndex val mutableRow = new SpecificInternalRow(attributes.map(_.dataType)) + val localEmptyStringsAsNulls = emptyStringsAsNulls // for serializability val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => val hconf = broadcastedHadoopConf.value.value val deserializer = deserializerClass.newInstance() deserializer.initialize(hconf, localTableDesc.getProperties) - HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer) + HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow, deserializer, + localEmptyStringsAsNulls) } deserializedHadoopRDD @@ -227,7 +245,7 @@ class HadoopTableReader( def fillPartitionKeys(rawPartValues: Array[String], row: InternalRow): Unit = { partitionKeyAttrs.foreach { case (attr, ordinal) => val partOrdinal = partitionKeys.indexOf(attr) - row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) + row(ordinal) = cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) } } @@ -236,6 +254,7 @@ class HadoopTableReader( val tableProperties = tableDesc.getProperties + val localEmptyStringsAsNulls = emptyStringsAsNulls // for serializability // Create local references so that the outer object isn't serialized. val localTableDesc = tableDesc createHadoopRdd(localTableDesc, inputPathStr, ifc).mapPartitions { iter => @@ -257,7 +276,7 @@ class HadoopTableReader( // fill the non partition key attributes HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, - mutableRow, tableSerDe) + mutableRow, tableSerDe, localEmptyStringsAsNulls) } }.toSeq @@ -273,13 +292,12 @@ class HadoopTableReader( * If `filterOpt` is defined, then it will be used to filter files from `path`. These files are * returned in a single, comma-separated string. */ - private def applyFilterIfNeeded(path: Path, filterOpt: Option[PathFilter]): String = { + private def applyFilterIfNeeded(path: Path, filterOpt: Option[PathFilter]): Seq[String] = { filterOpt match { case Some(filter) => val fs = path.getFileSystem(hadoopConf) - val filteredFiles = fs.listStatus(path, filter).map(_.getPath.toString) - filteredFiles.mkString(",") - case None => path.toString + fs.listStatus(path, filter).map(_.getPath.toString) + case None => Seq(path.toString) } } @@ -289,10 +307,10 @@ class HadoopTableReader( */ private def createHadoopRdd( tableDesc: TableDesc, - path: String, + paths: Seq[String], inputFormatClass: Class[InputFormat[Writable, Writable]]): RDD[Writable] = { - val initializeJobConfFunc = HadoopTableReader.initializeLocalJobConfFunc(path, tableDesc) _ + val initializeJobConfFunc = HadoopTableReader.initializeLocalJobConfFunc(paths, tableDesc) _ val rdd = new HadoopRDD( sparkSession.sparkContext, @@ -337,8 +355,8 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { * Curried. After given an argument for 'path', the resulting JobConf => Unit closure is used to * instantiate a HadoopRDD. */ - def initializeLocalJobConfFunc(path: String, tableDesc: TableDesc)(jobConf: JobConf) { - FileInputFormat.setInputPaths(jobConf, Seq[Path](new Path(path)): _*) + def initializeLocalJobConfFunc(paths: Seq[String], tableDesc: TableDesc)(jobConf: JobConf) { + FileInputFormat.setInputPaths(jobConf, paths.map { pathStr => new Path(pathStr) }: _*) if (tableDesc != null) { HiveTableUtil.configureJobPropertiesForStorageHandler(tableDesc, jobConf, true) Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) @@ -356,6 +374,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { * positions in the output schema * @param mutableRow A reusable `MutableRow` that should be filled * @param tableDeser Table Deserializer + * @param emptyStringsAsNulls whether to treat empty strings as nulls * @return An `Iterator[Row]` transformed from `iterator` */ def fillObject( @@ -363,7 +382,8 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { rawDeser: Deserializer, nonPartitionKeyAttrs: Seq[(Attribute, Int)], mutableRow: InternalRow, - tableDeser: Deserializer): Iterator[InternalRow] = { + tableDeser: Deserializer, + emptyStringsAsNulls: Boolean): Iterator[InternalRow] = { val soi = if (rawDeser.getObjectInspector.equals(tableDeser.getObjectInspector)) { rawDeser.getObjectInspector.asInstanceOf[StructObjectInspector] @@ -399,9 +419,27 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { (value: Any, row: InternalRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) case oi: DoubleObjectInspector => (value: Any, row: InternalRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + case oi: HiveVarcharObjectInspector if emptyStringsAsNulls => + (value: Any, row: InternalRow, ordinal: Int) => { + val strValue = UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue) + if (strValue == UTF8String.EMPTY_UTF8) { + row.update(ordinal, null) + } else { + row.update(ordinal, strValue) + } + } case oi: HiveVarcharObjectInspector => (value: Any, row: InternalRow, ordinal: Int) => row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) + case oi: StringObjectInspector if emptyStringsAsNulls => + (value: Any, row: InternalRow, ordinal: Int) => { + val strValue = UTF8String.fromString(oi.getPrimitiveJavaObject(value)) + if (strValue == UTF8String.EMPTY_UTF8) { + row.update(ordinal, null) + } else { + row.update(ordinal, strValue) + } + } case oi: HiveCharObjectInspector => (value: Any, row: InternalRow, ordinal: Int) => row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 387ec4f967233..2cf11f41a10da 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -22,7 +22,6 @@ import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import scala.language.reflectiveCalls import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -48,6 +47,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.client.HiveClientImpl._ import org.apache.spark.sql.types._ import org.apache.spark.util.{CircularBuffer, Utils} @@ -839,7 +839,7 @@ private[hive] object HiveClientImpl { } // after SPARK-19279, it is not allowed to create a hive table with an empty schema, // so here we should not add a default col schema - if (schema.isEmpty && DDLUtils.isDatasourceTable(table)) { + if (schema.isEmpty && HiveExternalCatalog.isDatasourceTable(table)) { // This is a hack to preserve existing behavior. Before Spark 2.0, we do not // set a default serde here (this was done in Hive), and so if the user provides // an empty schema Hive would automatically populate the schema with a single diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index e95f9ea480431..b8aa067cdb903 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -22,7 +22,6 @@ import java.lang.reflect.InvocationTargetException import java.net.{URL, URLClassLoader} import java.util -import scala.language.reflectiveCalls import scala.util.Try import org.apache.commons.io.{FileUtils, IOUtils} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 41c6b18e9d794..65e8b4e3c725c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -62,7 +62,7 @@ case class CreateHiveTableAsSelectCommand( Map(), query, overwrite = false, - ifNotExists = false)).toRdd + ifPartitionNotExists = false)).toRdd } else { // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data @@ -78,7 +78,7 @@ case class CreateHiveTableAsSelectCommand( Map(), query, overwrite = true, - ifNotExists = false)).toRdd + ifPartitionNotExists = false)).toRdd } catch { case NonFatal(e) => // drop the created table. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 666548d1a490b..2ce8ccfb35e0c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -30,13 +30,15 @@ import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.catalog.CatalogRelation +import org.apache.spark.sql.catalyst.analysis.CastSupport +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.client.HiveClientImpl +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, DataType} import org.apache.spark.util.Utils @@ -50,14 +52,16 @@ import org.apache.spark.util.Utils private[hive] case class HiveTableScanExec( requestedAttributes: Seq[Attribute], - relation: CatalogRelation, + relation: HiveTableRelation, partitionPruningPred: Seq[Expression])( @transient private val sparkSession: SparkSession) - extends LeafExecNode { + extends LeafExecNode with CastSupport { require(partitionPruningPred.isEmpty || relation.isPartitioned, "Partition pruning predicates only supported for partitioned tables.") + override def conf: SQLConf = sparkSession.sessionState.conf + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -104,7 +108,7 @@ case class HiveTableScanExec( hadoopConf) private def castFromString(value: String, dataType: DataType) = { - Cast(Literal(value), dataType).eval(null) + cast(Literal(value), dataType).eval(null) } private def addColumnMetadataToConf(hiveConf: Configuration): Unit = { @@ -205,8 +209,8 @@ case class HiveTableScanExec( val input: AttributeSeq = relation.output HiveTableScanExec( requestedAttributes.map(QueryPlan.normalizeExprId(_, input)), - relation.canonicalized.asInstanceOf[CatalogRelation], - partitionPruningPred.map(QueryPlan.normalizeExprId(_, input)))(sparkSession) + relation.canonicalized.asInstanceOf[HiveTableRelation], + QueryPlan.normalizePredicates(partitionPruningPred, input))(sparkSession) } override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3682dc850790e..66ee5d4581e7e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import java.io.IOException +import java.io.{File, IOException} import java.net.URI import java.text.SimpleDateFormat import java.util.{Date, Locale, Random} @@ -71,14 +71,15 @@ import org.apache.spark.SparkException * }}}. * @param query the logical plan representing data to write to. * @param overwrite overwrite existing table or partitions. - * @param ifNotExists If true, only write if the table or partition does not exist. + * @param ifPartitionNotExists If true, only write if the partition does not exist. + * Only valid for static partitions. */ case class InsertIntoHiveTable( table: CatalogTable, partition: Map[String, Option[String]], query: LogicalPlan, overwrite: Boolean, - ifNotExists: Boolean) extends RunnableCommand { + ifPartitionNotExists: Boolean) extends RunnableCommand { override protected def innerChildren: Seq[LogicalPlan] = query :: Nil @@ -97,12 +98,24 @@ case class InsertIntoHiveTable( val inputPathUri: URI = inputPath.toUri val inputPathName: String = inputPathUri.getPath val fs: FileSystem = inputPath.getFileSystem(hadoopConf) - val stagingPathName: String = + var stagingPathName: String = if (inputPathName.indexOf(stagingDir) == -1) { new Path(inputPathName, stagingDir).toString } else { inputPathName.substring(0, inputPathName.indexOf(stagingDir) + stagingDir.length) } + + // SPARK-20594: This is a walk-around fix to resolve a Hive bug. Hive requires that the + // staging directory needs to avoid being deleted when users set hive.exec.stagingdir + // under the table directory. + if (FileUtils.isSubDir(new Path(stagingPathName), inputPath, fs) && + !stagingPathName.stripPrefix(inputPathName).stripPrefix(File.separator).startsWith(".")) { + logDebug(s"The staging dir '$stagingPathName' should be a child directory starts " + + "with '.' to avoid being deleted if we set hive.exec.stagingdir under the table " + + "directory.") + stagingPathName = new Path(inputPathName, ".hive-staging").toString + } + val dir: Path = fs.makeQualified( new Path(stagingPathName + "_" + executionId + "-" + TaskRunner.getTaskRunnerID)) @@ -301,13 +314,6 @@ case class InsertIntoHiveTable( outputPath = tmpLocation.toString, isAppend = false) - val partitionAttributes = partitionColumnNames.takeRight(numDynamicPartitions).map { name => - query.resolve(name :: Nil, sparkSession.sessionState.analyzer.resolver).getOrElse { - throw new AnalysisException( - s"Unable to resolve $name given [${query.output.map(_.name).mkString(", ")}]") - }.asInstanceOf[Attribute] - } - FileFormatWriter.write( sparkSession = sparkSession, queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, @@ -315,7 +321,7 @@ case class InsertIntoHiveTable( committer = committer, outputSpec = FileFormatWriter.OutputSpec(tmpLocation.toString, Map.empty), hadoopConf = hadoopConf, - partitionColumns = partitionAttributes, + partitionColumnNames = partitionColumnNames.takeRight(numDynamicPartitions), bucketSpec = None, refreshFunction = _ => (), options = Map.empty) @@ -342,7 +348,7 @@ case class InsertIntoHiveTable( var doHiveOverwrite = overwrite - if (oldPart.isEmpty || !ifNotExists) { + if (oldPart.isEmpty || !ifPartitionNotExists) { // SPARK-18107: Insert overwrite runs much slower than hive-client. // Newer Hive largely improves insert overwrite performance. As Spark uses older Hive // version and we may not want to catch up new Hive version every time. We delete the @@ -387,7 +393,13 @@ case class InsertIntoHiveTable( // Attempt to delete the staging directory and the inclusive files. If failed, the files are // expected to be dropped at the normal termination of VM since deleteOnExit is used. try { - createdTempDir.foreach { path => path.getFileSystem(hadoopConf).delete(path, true) } + createdTempDir.foreach { path => + val fs = path.getFileSystem(hadoopConf) + if (fs.delete(path, true)) { + // If we successfully delete the staging directory, remove it from FileSystem's cache. + fs.cancelDeleteOnExit(path) + } + } } catch { case NonFatal(e) => logWarning(s"Unable to delete staging directory: $stagingDir.\n" + e) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index d9bb1f8c7edcc..4612cce80effd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.test import java.io.File +import java.net.URI import java.util.{Set => JavaSet} import scala.collection.JavaConverters._ @@ -486,16 +487,16 @@ private[hive] class TestHiveSparkSession( } } + // Clean out the Hive warehouse between each suite + val warehouseDir = new File(new URI(sparkContext.conf.get("spark.sql.warehouse.dir")).getPath) + Utils.deleteRecursively(warehouseDir) + warehouseDir.mkdir() + sharedState.cacheManager.clearCache() loadedTables.clear() - sessionState.catalog.clearTempTables() - sessionState.catalog.tableRelationCache.invalidateAll() - + sessionState.catalog.reset() metadataHive.reset() - FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). - foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } - // HDFS root scratch dir requires the write all (733) permission. For each connecting user, // an HDFS scratch dir: ${hive.exec.scratchdir}/ is created, with // ${hive.scratch.dir.permission}. To resolve the permission issue, the simplest way is to diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala deleted file mode 100644 index 705d43f1f3aba..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala +++ /dev/null @@ -1,264 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.net.URI - -import org.apache.hadoop.fs.Path -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.hive.client.HiveClient -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils - - -class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest - with SQLTestUtils with TestHiveSingleton with BeforeAndAfterEach { - - // To test `HiveExternalCatalog`, we need to read/write the raw table meta from/to hive client. - val hiveClient: HiveClient = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - - val tempDir = Utils.createTempDir().getCanonicalFile - val tempDirUri = tempDir.toURI - val tempDirStr = tempDir.getAbsolutePath - - override def beforeEach(): Unit = { - sql("CREATE DATABASE test_db") - for ((tbl, _) <- rawTablesAndExpectations) { - hiveClient.createTable(tbl, ignoreIfExists = false) - } - } - - override def afterEach(): Unit = { - Utils.deleteRecursively(tempDir) - hiveClient.dropDatabase("test_db", ignoreIfNotExists = false, cascade = true) - } - - private def getTableMetadata(tableName: String): CatalogTable = { - spark.sharedState.externalCatalog.getTable("test_db", tableName) - } - - private def defaultTableURI(tableName: String): URI = { - spark.sessionState.catalog.defaultTablePath(TableIdentifier(tableName, Some("test_db"))) - } - - // Raw table metadata that are dumped from tables created by Spark 2.0. Note that, all spark - // versions prior to 2.1 would generate almost same raw table metadata for a specific table. - val simpleSchema = new StructType().add("i", "int") - val partitionedSchema = new StructType().add("i", "int").add("j", "int") - - lazy val hiveTable = CatalogTable( - identifier = TableIdentifier("tbl1", Some("test_db")), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty.copy( - inputFormat = Some("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), - schema = simpleSchema) - - lazy val externalHiveTable = CatalogTable( - identifier = TableIdentifier("tbl2", Some("test_db")), - tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty.copy( - locationUri = Some(tempDirUri), - inputFormat = Some("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), - schema = simpleSchema) - - lazy val partitionedHiveTable = CatalogTable( - identifier = TableIdentifier("tbl3", Some("test_db")), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty.copy( - inputFormat = Some("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), - schema = partitionedSchema, - partitionColumnNames = Seq("j")) - - - val simpleSchemaJson = - """ - |{ - | "type": "struct", - | "fields": [{ - | "name": "i", - | "type": "integer", - | "nullable": true, - | "metadata": {} - | }] - |} - """.stripMargin - - val partitionedSchemaJson = - """ - |{ - | "type": "struct", - | "fields": [{ - | "name": "i", - | "type": "integer", - | "nullable": true, - | "metadata": {} - | }, - | { - | "name": "j", - | "type": "integer", - | "nullable": true, - | "metadata": {} - | }] - |} - """.stripMargin - - lazy val dataSourceTable = CatalogTable( - identifier = TableIdentifier("tbl4", Some("test_db")), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty.copy( - properties = Map("path" -> defaultTableURI("tbl4").toString)), - schema = new StructType(), - provider = Some("json"), - properties = Map( - "spark.sql.sources.provider" -> "json", - "spark.sql.sources.schema.numParts" -> "1", - "spark.sql.sources.schema.part.0" -> simpleSchemaJson)) - - lazy val hiveCompatibleDataSourceTable = CatalogTable( - identifier = TableIdentifier("tbl5", Some("test_db")), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty.copy( - properties = Map("path" -> defaultTableURI("tbl5").toString)), - schema = simpleSchema, - provider = Some("parquet"), - properties = Map( - "spark.sql.sources.provider" -> "parquet", - "spark.sql.sources.schema.numParts" -> "1", - "spark.sql.sources.schema.part.0" -> simpleSchemaJson)) - - lazy val partitionedDataSourceTable = CatalogTable( - identifier = TableIdentifier("tbl6", Some("test_db")), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty.copy( - properties = Map("path" -> defaultTableURI("tbl6").toString)), - schema = new StructType(), - provider = Some("json"), - properties = Map( - "spark.sql.sources.provider" -> "json", - "spark.sql.sources.schema.numParts" -> "1", - "spark.sql.sources.schema.part.0" -> partitionedSchemaJson, - "spark.sql.sources.schema.numPartCols" -> "1", - "spark.sql.sources.schema.partCol.0" -> "j")) - - lazy val externalDataSourceTable = CatalogTable( - identifier = TableIdentifier("tbl7", Some("test_db")), - tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty.copy( - locationUri = Some(new URI(defaultTableURI("tbl7") + "-__PLACEHOLDER__")), - properties = Map("path" -> tempDirStr)), - schema = new StructType(), - provider = Some("json"), - properties = Map( - "spark.sql.sources.provider" -> "json", - "spark.sql.sources.schema.numParts" -> "1", - "spark.sql.sources.schema.part.0" -> simpleSchemaJson)) - - lazy val hiveCompatibleExternalDataSourceTable = CatalogTable( - identifier = TableIdentifier("tbl8", Some("test_db")), - tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty.copy( - locationUri = Some(tempDirUri), - properties = Map("path" -> tempDirStr)), - schema = simpleSchema, - properties = Map( - "spark.sql.sources.provider" -> "parquet", - "spark.sql.sources.schema.numParts" -> "1", - "spark.sql.sources.schema.part.0" -> simpleSchemaJson)) - - lazy val dataSourceTableWithoutSchema = CatalogTable( - identifier = TableIdentifier("tbl9", Some("test_db")), - tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty.copy( - locationUri = Some(new URI(defaultTableURI("tbl9") + "-__PLACEHOLDER__")), - properties = Map("path" -> tempDirStr)), - schema = new StructType(), - provider = Some("json"), - properties = Map("spark.sql.sources.provider" -> "json")) - - // A list of all raw tables we want to test, with their expected schema. - lazy val rawTablesAndExpectations = Seq( - hiveTable -> simpleSchema, - externalHiveTable -> simpleSchema, - partitionedHiveTable -> partitionedSchema, - dataSourceTable -> simpleSchema, - hiveCompatibleDataSourceTable -> simpleSchema, - partitionedDataSourceTable -> partitionedSchema, - externalDataSourceTable -> simpleSchema, - hiveCompatibleExternalDataSourceTable -> simpleSchema, - dataSourceTableWithoutSchema -> new StructType()) - - test("make sure we can read table created by old version of Spark") { - for ((tbl, expectedSchema) <- rawTablesAndExpectations) { - val readBack = getTableMetadata(tbl.identifier.table) - assert(readBack.schema.sameType(expectedSchema)) - - if (tbl.tableType == CatalogTableType.EXTERNAL) { - // trim the URI prefix - val tableLocation = readBack.storage.locationUri.get.getPath - val expectedLocation = tempDir.toURI.getPath.stripSuffix("/") - assert(tableLocation == expectedLocation) - } - } - } - - test("make sure we can alter table location created by old version of Spark") { - withTempDir { dir => - for ((tbl, _) <- rawTablesAndExpectations if tbl.tableType == CatalogTableType.EXTERNAL) { - val path = dir.toURI.toString.stripSuffix("/") - sql(s"ALTER TABLE ${tbl.identifier} SET LOCATION '$path'") - - val readBack = getTableMetadata(tbl.identifier.table) - - // trim the URI prefix - val actualTableLocation = readBack.storage.locationUri.get.getPath - val expected = dir.toURI.getPath.stripSuffix("/") - assert(actualTableLocation == expected) - } - } - } - - test("make sure we can rename table created by old version of Spark") { - for ((tbl, expectedSchema) <- rawTablesAndExpectations) { - val newName = tbl.identifier.table + "_renamed" - sql(s"ALTER TABLE ${tbl.identifier} RENAME TO $newName") - - val readBack = getTableMetadata(newName) - assert(readBack.schema.sameType(expectedSchema)) - - // trim the URI prefix - val actualTableLocation = readBack.storage.locationUri.get.getPath - val expectedLocation = if (tbl.tableType == CatalogTableType.EXTERNAL) { - tempDir.toURI.getPath.stripSuffix("/") - } else { - // trim the URI prefix - defaultTableURI(newName).getPath - } - assert(actualTableLocation == expectedLocation) - } - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index bd54c043c6ec4..d43534d5914d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -63,4 +63,30 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { assert(!rawTable.properties.contains(HiveExternalCatalog.DATASOURCE_PROVIDER)) assert(DDLUtils.isHiveTable(externalCatalog.getTable("db1", "hive_tbl"))) } + + Seq("parquet", "hive").foreach { format => + test(s"Partition columns should be put at the end of table schema for the format $format") { + val catalog = newBasicCatalog() + val newSchema = new StructType() + .add("col1", "int") + .add("col2", "string") + .add("partCol1", "int") + .add("partCol2", "string") + val table = CatalogTable( + identifier = TableIdentifier("tbl", Some("db1")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType() + .add("col1", "int") + .add("partCol1", "int") + .add("partCol2", "string") + .add("col2", "string"), + provider = Some(format), + partitionColumnNames = Seq("partCol1", "partCol2")) + catalog.createTable(table, ignoreIfExists = false) + + val restoredTable = externalCatalog.getTable("db1", "tbl") + assert(restoredTable.schema == newSchema) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala new file mode 100644 index 0000000000000..305f5b533d592 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File +import java.nio.file.Files + +import org.apache.spark.TestUtils +import org.apache.spark.sql.{QueryTest, Row, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogTableType +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils + +/** + * Test HiveExternalCatalog backward compatibility. + * + * Note that, this test suite will automatically download spark binary packages of different + * versions to a local directory `/tmp/spark-test`. If there is already a spark folder with + * expected version under this local directory, e.g. `/tmp/spark-test/spark-2.0.3`, we will skip the + * downloading for this spark version. + */ +class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { + private val wareHousePath = Utils.createTempDir(namePrefix = "warehouse") + private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data") + // For local test, you can set `sparkTestingDir` to a static value like `/tmp/test-spark`, to + // avoid downloading Spark of different versions in each run. + private val sparkTestingDir = Utils.createTempDir(namePrefix = "test-spark") + private val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + + override def afterAll(): Unit = { + Utils.deleteRecursively(wareHousePath) + Utils.deleteRecursively(tmpDataDir) + Utils.deleteRecursively(sparkTestingDir) + super.afterAll() + } + + private def downloadSpark(version: String): Unit = { + import scala.sys.process._ + + val url = s"https://d3kbcqa49mib13.cloudfront.net/spark-$version-bin-hadoop2.7.tgz" + + Seq("wget", url, "-q", "-P", sparkTestingDir.getCanonicalPath).! + + val downloaded = new File(sparkTestingDir, s"spark-$version-bin-hadoop2.7.tgz").getCanonicalPath + val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath + + Seq("mkdir", targetDir).! + + Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! + + Seq("rm", downloaded).! + } + + private def genDataDir(name: String): String = { + new File(tmpDataDir, name).getCanonicalPath + } + + override def beforeAll(): Unit = { + super.beforeAll() + + val tempPyFile = File.createTempFile("test", ".py") + Files.write(tempPyFile.toPath, + s""" + |from pyspark.sql import SparkSession + | + |spark = SparkSession.builder.enableHiveSupport().getOrCreate() + |version_index = spark.conf.get("spark.sql.test.version.index", None) + | + |spark.sql("create table data_source_tbl_{} using json as select 1 i".format(version_index)) + | + |spark.sql("create table hive_compatible_data_source_tbl_" + version_index + \\ + | " using parquet as select 1 i") + | + |json_file = "${genDataDir("json_")}" + str(version_index) + |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file) + |spark.sql("create table external_data_source_tbl_" + version_index + \\ + | "(i int) using json options (path '{}')".format(json_file)) + | + |parquet_file = "${genDataDir("parquet_")}" + str(version_index) + |spark.range(1, 2).selectExpr("cast(id as int) as i").write.parquet(parquet_file) + |spark.sql("create table hive_compatible_external_data_source_tbl_" + version_index + \\ + | "(i int) using parquet options (path '{}')".format(parquet_file)) + | + |json_file2 = "${genDataDir("json2_")}" + str(version_index) + |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file2) + |spark.sql("create table external_table_without_schema_" + version_index + \\ + | " using json options (path '{}')".format(json_file2)) + | + |spark.sql("create view v_{} as select 1 i".format(version_index)) + """.stripMargin.getBytes("utf8")) + + PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) => + val sparkHome = new File(sparkTestingDir, s"spark-$version") + if (!sparkHome.exists()) { + downloadSpark(version) + } + + val args = Seq( + "--name", "prepare testing tables", + "--master", "local[2]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.sql.warehouse.dir=${wareHousePath.getCanonicalPath}", + "--conf", s"spark.sql.test.version.index=$index", + "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}", + tempPyFile.getCanonicalPath) + runSparkSubmit(args, Some(sparkHome.getCanonicalPath)) + } + + tempPyFile.delete() + } + + test("backward compatibility") { + val args = Seq( + "--class", PROCESS_TABLES.getClass.getName.stripSuffix("$"), + "--name", "HiveExternalCatalog backward compatibility test", + "--master", "local[2]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.sql.warehouse.dir=${wareHousePath.getCanonicalPath}", + "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}", + unusedJar.toString) + runSparkSubmit(args) + } +} + +object PROCESS_TABLES extends QueryTest with SQLTestUtils { + // Tests the latest version of every release line. + val testingVersions = Seq("2.0.2", "2.1.1", "2.2.0") + + protected var spark: SparkSession = _ + + def main(args: Array[String]): Unit = { + val session = SparkSession.builder() + .enableHiveSupport() + .getOrCreate() + spark = session + + testingVersions.indices.foreach { index => + Seq( + s"data_source_tbl_$index", + s"hive_compatible_data_source_tbl_$index", + s"external_data_source_tbl_$index", + s"hive_compatible_external_data_source_tbl_$index", + s"external_table_without_schema_$index").foreach { tbl => + val tableMeta = spark.sharedState.externalCatalog.getTable("default", tbl) + + // make sure we can insert and query these tables. + session.sql(s"insert into $tbl select 2") + checkAnswer(session.sql(s"select * from $tbl"), Row(1) :: Row(2) :: Nil) + checkAnswer(session.sql(s"select i from $tbl where i > 1"), Row(2)) + + // make sure we can rename table. + val newName = tbl + "_renamed" + sql(s"ALTER TABLE $tbl RENAME TO $newName") + val readBack = spark.sharedState.externalCatalog.getTable("default", newName) + + val actualTableLocation = readBack.storage.locationUri.get.getPath + val expectedLocation = if (tableMeta.tableType == CatalogTableType.EXTERNAL) { + tableMeta.storage.locationUri.get.getPath + } else { + spark.sessionState.catalog.defaultTablePath(TableIdentifier(newName, None)).getPath + } + assert(actualTableLocation == expectedLocation) + + // make sure we can alter table location. + withTempDir { dir => + val path = dir.toURI.toString.stripSuffix("/") + sql(s"ALTER TABLE ${tbl}_renamed SET LOCATION '$path'") + val readBack = spark.sharedState.externalCatalog.getTable("default", tbl + "_renamed") + val actualTableLocation = readBack.storage.locationUri.get.getPath + val expected = dir.toURI.getPath.stripSuffix("/") + assert(actualTableLocation == expected) + } + } + + // test permanent view + checkAnswer(sql(s"select i from v_$index"), Row(1)) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index d8fd68b63d1eb..c58060754f793 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.hive +import java.io.File + +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTableType @@ -182,3 +185,86 @@ class DataSourceWithHiveMetastoreCatalogSuite } } } + +class ParquetLocationSelectionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.HadoopFileSelector + import org.apache.spark.sql.hive.test.TestHive + private val hmc = new HiveMetastoreCatalog(spark) + // ensuring temp directories + private val baseDir = { + val base = + File.createTempFile( + "selectParquetLocationDirectories", + "1", + TestHive.sparkSession.hiveFilesTemp) + base.delete() + base.mkdirs() + base + } + + test(s"With Selector selecting from ${baseDir.toString}") { + val fullpath = { (somewhere: String, sometable: String) => + s"${baseDir.toString}/$somewhere/$sometable" + } + spark.sharedState.externalCatalog.setHadoopFileSelector(new HadoopFileSelector() { + override def selectFiles( + sometable: String, + fs: FileSystem, + somewhere: Path): Option[Seq[Path]] = { + Some(Seq(new Path(fullpath(somewhere.toString, sometable)))) + } + }) + + // ensure directory existence for somewhere/sometable + val somewhereSometable = new File(fullpath("somewhere", "sometable")) + somewhereSometable.mkdirs() + // somewhere/sometable is a directory => will be selected + assertResult(Seq(new Path(fullpath("somewhere", "sometable")))) { + hmc.selectParquetLocationDirectories("sometable", Option(new Path("somewhere"))) + } + + // ensure file existence for somewhere/sometable + somewhereSometable.delete() + somewhereSometable.createNewFile() + // somewhere/sometable is a file => will not be selected + assertResult(Seq(new Path("somewhere"))) { + hmc.selectParquetLocationDirectories("otherplace", Option(new Path("somewhere"))) + } + + // no location specified, none selected + assertResult(Seq(null)) { + hmc.selectParquetLocationDirectories("sometable", Option(null)) + } + } + + test("With Selector selecting None") { + spark.sharedState.externalCatalog.setHadoopFileSelector(new HadoopFileSelector() { + override def selectFiles( + tableName: String, + fs: FileSystem, + basePath: Path): Option[Seq[Path]] = None + }) + + // none selected + assertResult(Seq(new Path("somewhere"))) { + hmc.selectParquetLocationDirectories("sometable", Option(new Path("somewhere"))) + } + // none selected + assertResult(Seq(null)) { + hmc.selectParquetLocationDirectories("sometable", Option(null)) + } + } + + test("Without Selector") { + spark.sharedState.externalCatalog.unsetHadoopFileSelector() + + // none selected + assertResult(Seq(new Path("somewhere"))) { + hmc.selectParquetLocationDirectories("sometable", Option(new Path("somewhere"))) + } + // none selected + assertResult(Seq(null)) { + hmc.selectParquetLocationDirectories("sometable", Option(null)) + } + } +} \ No newline at end of file diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index 319d02613f00a..d271acc63de08 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -46,7 +46,7 @@ class HiveSchemaInferenceSuite override def afterEach(): Unit = { super.afterEach() - spark.sessionState.catalog.tableRelationCache.invalidateAll() + spark.sessionState.catalog.invalidateAllCachedTables() FileStatusCache.resetForTesting() } @@ -104,7 +104,7 @@ class HiveSchemaInferenceSuite identifier = TableIdentifier(table = TEST_TABLE_NAME, database = Option(DATABASE)), tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat( - locationUri = Option(new java.net.URI(dir.getAbsolutePath)), + locationUri = Option(dir.toURI), inputFormat = serde.inputFormat, outputFormat = serde.outputFormat, serde = serde.serde, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 5f15a705a2e99..cf145c845eef0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -18,17 +18,11 @@ package org.apache.spark.sql.hive import java.io.{BufferedWriter, File, FileWriter} -import java.sql.Timestamp -import java.util.Date -import scala.collection.mutable.ArrayBuffer import scala.tools.nsc.Properties import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfterEach, Matchers} -import org.scalatest.concurrent.Timeouts -import org.scalatest.exceptions.TestFailedDueToTimeoutException -import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.internal.Logging @@ -38,7 +32,6 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} -import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.sql.types.{DecimalType, StructType} import org.apache.spark.util.{ResetSystemProperties, Utils} @@ -46,11 +39,10 @@ import org.apache.spark.util.{ResetSystemProperties, Utils} * This suite tests spark-submit with applications using HiveContext. */ class HiveSparkSubmitSuite - extends SparkFunSuite + extends SparkSubmitTestUtils with Matchers with BeforeAndAfterEach - with ResetSystemProperties - with Timeouts { + with ResetSystemProperties { // TODO: rewrite these or mark them as slow tests to be run sparingly @@ -333,71 +325,6 @@ class HiveSparkSubmitSuite unusedJar.toString) runSparkSubmit(argsForShowTables) } - - // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. - // This is copied from org.apache.spark.deploy.SparkSubmitSuite - private def runSparkSubmit(args: Seq[String]): Unit = { - val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - val history = ArrayBuffer.empty[String] - val sparkSubmit = if (Utils.isWindows) { - // On Windows, `ProcessBuilder.directory` does not change the current working directory. - new File("..\\..\\bin\\spark-submit.cmd").getAbsolutePath - } else { - "./bin/spark-submit" - } - val commands = Seq(sparkSubmit) ++ args - val commandLine = commands.mkString("'", "' '", "'") - - val builder = new ProcessBuilder(commands: _*).directory(new File(sparkHome)) - val env = builder.environment() - env.put("SPARK_TESTING", "1") - env.put("SPARK_HOME", sparkHome) - - def captureOutput(source: String)(line: String): Unit = { - // This test suite has some weird behaviors when executed on Jenkins: - // - // 1. Sometimes it gets extremely slow out of unknown reason on Jenkins. Here we add a - // timestamp to provide more diagnosis information. - // 2. Log lines are not correctly redirected to unit-tests.log as expected, so here we print - // them out for debugging purposes. - val logLine = s"${new Timestamp(new Date().getTime)} - $source> $line" - // scalastyle:off println - println(logLine) - // scalastyle:on println - history += logLine - } - - val process = builder.start() - new ProcessOutputCapturer(process.getInputStream, captureOutput("stdout")).start() - new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() - - try { - val exitCode = failAfter(300.seconds) { process.waitFor() } - if (exitCode != 0) { - // include logs in output. Note that logging is async and may not have completed - // at the time this exception is raised - Thread.sleep(1000) - val historyLog = history.mkString("\n") - fail { - s"""spark-submit returned with exit code $exitCode. - |Command line: $commandLine - | - |$historyLog - """.stripMargin - } - } - } catch { - case to: TestFailedDueToTimeoutException => - val historyLog = history.mkString("\n") - fail(s"Timeout of $commandLine" + - s" See the log4j logs for more detail." + - s"\n$historyLog", to) - case t: Throwable => throw t - } finally { - // Ensure we still kill the process in case it timed out - process.destroy() - } - } } object SetMetastoreURLTest extends Logging { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index d6999af84eac0..618e5b68ff8c0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -166,72 +166,54 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql("DROP TABLE tmp_table") } - test("INSERT OVERWRITE - partition IF NOT EXISTS") { - withTempDir { tmpDir => - val table = "table_with_partition" - withTable(table) { - val selQuery = s"select c1, p1, p2 from $table" - sql( - s""" - |CREATE TABLE $table(c1 string) - |PARTITIONED by (p1 string,p2 string) - |location '${tmpDir.toURI.toString}' - """.stripMargin) - sql( - s""" - |INSERT OVERWRITE TABLE $table - |partition (p1='a',p2='b') - |SELECT 'blarr' - """.stripMargin) - checkAnswer( - sql(selQuery), - Row("blarr", "a", "b")) - - sql( - s""" - |INSERT OVERWRITE TABLE $table - |partition (p1='a',p2='b') - |SELECT 'blarr2' - """.stripMargin) - checkAnswer( - sql(selQuery), - Row("blarr2", "a", "b")) + testPartitionedTable("INSERT OVERWRITE - partition IF NOT EXISTS") { tableName => + val selQuery = s"select a, b, c, d from $tableName" + sql( + s""" + |INSERT OVERWRITE TABLE $tableName + |partition (b=2, c=3) + |SELECT 1, 4 + """.stripMargin) + checkAnswer(sql(selQuery), Row(1, 2, 3, 4)) - var e = intercept[AnalysisException] { - sql( - s""" - |INSERT OVERWRITE TABLE $table - |partition (p1='a',p2) IF NOT EXISTS - |SELECT 'blarr3', 'newPartition' - """.stripMargin) - } - assert(e.getMessage.contains( - "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [p2]")) + sql( + s""" + |INSERT OVERWRITE TABLE $tableName + |partition (b=2, c=3) + |SELECT 5, 6 + """.stripMargin) + checkAnswer(sql(selQuery), Row(5, 2, 3, 6)) + + val e = intercept[AnalysisException] { + sql( + s""" + |INSERT OVERWRITE TABLE $tableName + |partition (b=2, c) IF NOT EXISTS + |SELECT 7, 8, 3 + """.stripMargin) + } + assert(e.getMessage.contains( + "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [c]")) - e = intercept[AnalysisException] { - sql( - s""" - |INSERT OVERWRITE TABLE $table - |partition (p1='a',p2) IF NOT EXISTS - |SELECT 'blarr3', 'b' - """.stripMargin) - } - assert(e.getMessage.contains( - "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [p2]")) + // If the partition already exists, the insert will overwrite the data + // unless users specify IF NOT EXISTS + sql( + s""" + |INSERT OVERWRITE TABLE $tableName + |partition (b=2, c=3) IF NOT EXISTS + |SELECT 9, 10 + """.stripMargin) + checkAnswer(sql(selQuery), Row(5, 2, 3, 6)) - // If the partition already exists, the insert will overwrite the data - // unless users specify IF NOT EXISTS - sql( - s""" - |INSERT OVERWRITE TABLE $table - |partition (p1='a',p2='b') IF NOT EXISTS - |SELECT 'blarr3' - """.stripMargin) - checkAnswer( - sql(selQuery), - Row("blarr2", "a", "b")) - } - } + // ADD PARTITION has the same effect, even if no actual data is inserted. + sql(s"ALTER TABLE $tableName ADD PARTITION (b=21, c=31)") + sql( + s""" + |INSERT OVERWRITE TABLE $tableName + |partition (b=21, c=31) IF NOT EXISTS + |SELECT 20, 24 + """.stripMargin) + checkAnswer(sql(selQuery), Row(5, 2, 3, 6)) } test("Insert ArrayType.containsNull == false") { @@ -486,6 +468,28 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef } } + test("SPARK-21165: the query schema of INSERT is changed after optimization") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + withTable("tab1", "tab2") { + Seq(("a", "b", 3)).toDF("word", "first", "length").write.saveAsTable("tab1") + + spark.sql( + """ + |CREATE TABLE tab2 (word string, length int) + |PARTITIONED BY (first string) + """.stripMargin) + + spark.sql( + """ + |INSERT INTO TABLE tab2 PARTITION(first) + |SELECT word, length, cast(first as string) as first FROM tab1 + """.stripMargin) + + checkAnswer(spark.table("tab2"), Row("a", 3, "b")) + } + } + } + testPartitionedTable("insertInto() should reject extra columns") { tableName => sql("CREATE TABLE t (a INT, b INT, c INT, d INT, e INT)") @@ -494,4 +498,15 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef spark.table("t").write.insertInto(tableName) } } + + test("SPARK-20594: hive.exec.stagingdir was deleted by Hive") { + // Set hive.exec.stagingdir under the table directory without start with ".". + withSQLConf("hive.exec.stagingdir" -> "./test") { + withTable("test_table") { + sql("CREATE TABLE test_table (key int)") + sql("INSERT OVERWRITE TABLE test_table SELECT 1") + checkAnswer(sql("SELECT * FROM test_table"), Row(1)) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index b554694815571..07d641d72e709 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -1359,30 +1359,4 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv sparkSession.sparkContext.conf.set(DEBUG_MODE, previousValue) } } - - test("SPARK-18464: support old table which doesn't store schema in table properties") { - withTable("old") { - withTempPath { path => - Seq(1 -> "a").toDF("i", "j").write.parquet(path.getAbsolutePath) - val tableDesc = CatalogTable( - identifier = TableIdentifier("old", Some("default")), - tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty.copy( - properties = Map("path" -> path.getAbsolutePath) - ), - schema = new StructType(), - provider = Some("parquet"), - properties = Map( - HiveExternalCatalog.DATASOURCE_PROVIDER -> "parquet")) - hiveClient.createTable(tableDesc, ignoreIfExists = false) - - checkAnswer(spark.table("old"), Row(1, "a")) - - val expectedSchema = StructType(Seq( - StructField("i", IntegerType, nullable = true), - StructField("j", StringType, nullable = true))) - assert(table("old").schema === expectedSchema) - } - } - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 43b6bf5feeb60..b2dc401ce1efc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.io.File +import java.sql.Timestamp import com.google.common.io.Files import org.apache.hadoop.fs.FileSystem @@ -68,4 +69,20 @@ class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingl sql("DROP TABLE IF EXISTS createAndInsertTest") } } + + test("SPARK-21739: Cast expression should initialize timezoneId") { + withTable("table_with_timestamp_partition") { + sql("CREATE TABLE table_with_timestamp_partition(value int) PARTITIONED BY (ts TIMESTAMP)") + sql("INSERT OVERWRITE TABLE table_with_timestamp_partition " + + "PARTITION (ts = '2010-01-01 00:00:00.000') VALUES (1)") + + // test for Cast expression in TableReader + checkAnswer(sql("SELECT * FROM table_with_timestamp_partition"), + Seq(Row(1, Timestamp.valueOf("2010-01-01 00:00:00.000")))) + + // test for Cast expression in HiveTableScanExec + checkAnswer(sql("SELECT value FROM table_with_timestamp_partition " + + "WHERE ts = '2010-01-01 00:00:00.000'"), Row(1)) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala new file mode 100644 index 0000000000000..4b28d4f362b80 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File +import java.sql.Timestamp +import java.util.Date + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.concurrent.Timeouts +import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer +import org.apache.spark.util.Utils + +trait SparkSubmitTestUtils extends SparkFunSuite with Timeouts { + + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. + // This is copied from org.apache.spark.deploy.SparkSubmitSuite + protected def runSparkSubmit(args: Seq[String], sparkHomeOpt: Option[String] = None): Unit = { + val sparkHome = sparkHomeOpt.getOrElse( + sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))) + val history = ArrayBuffer.empty[String] + val sparkSubmit = if (Utils.isWindows) { + // On Windows, `ProcessBuilder.directory` does not change the current working directory. + new File("..\\..\\bin\\spark-submit.cmd").getAbsolutePath + } else { + "./bin/spark-submit" + } + val commands = Seq(sparkSubmit) ++ args + val commandLine = commands.mkString("'", "' '", "'") + + val builder = new ProcessBuilder(commands: _*).directory(new File(sparkHome)) + val env = builder.environment() + env.put("SPARK_TESTING", "1") + env.put("SPARK_HOME", sparkHome) + + def captureOutput(source: String)(line: String): Unit = { + // This test suite has some weird behaviors when executed on Jenkins: + // + // 1. Sometimes it gets extremely slow out of unknown reason on Jenkins. Here we add a + // timestamp to provide more diagnosis information. + // 2. Log lines are not correctly redirected to unit-tests.log as expected, so here we print + // them out for debugging purposes. + val logLine = s"${new Timestamp(new Date().getTime)} - $source> $line" + // scalastyle:off println + println(logLine) + // scalastyle:on println + history += logLine + } + + val process = builder.start() + new ProcessOutputCapturer(process.getInputStream, captureOutput("stdout")).start() + new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() + + try { + val exitCode = failAfter(300.seconds) { process.waitFor() } + if (exitCode != 0) { + // include logs in output. Note that logging is async and may not have completed + // at the time this exception is raised + Thread.sleep(1000) + val historyLog = history.mkString("\n") + fail { + s"""spark-submit returned with exit code $exitCode. + |Command line: $commandLine + | + |$historyLog + """.stripMargin + } + } + } catch { + case to: TestFailedDueToTimeoutException => + val historyLog = history.mkString("\n") + fail(s"Timeout of $commandLine" + + s" See the log4j logs for more detail." + + s"\n$historyLog", to) + case t: Throwable => throw t + } finally { + // Ensure we still kill the process in case it timed out + process.destroy() + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 3191b9975fbf9..a9caad897c589 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, HiveTableRelation} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ @@ -31,6 +31,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ + class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { test("Hive serde tables should fallback to HDFS for size estimation") { @@ -59,7 +60,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto |LOCATION '${tempDir.toURI}'""".stripMargin) val relation = spark.table("csv_table").queryExecution.analyzed.children.head - .asInstanceOf[CatalogRelation] + .asInstanceOf[HiveTableRelation] val properties = relation.tableMeta.properties assert(properties("totalSize").toLong <= 0, "external table totalSize must be <= 0") @@ -125,6 +126,77 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto TableIdentifier("tempTable"), ignoreIfNotExists = true, purge = false) } + test("SPARK-21079 - analyze table with location different than that of individual partitions") { + def queryTotalSize(tableName: String): BigInt = + spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + + val tableName = "analyzeTable_part" + withTable(tableName) { + withTempPath { path => + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)") + + val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03") + partitionDates.foreach { ds => + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds') SELECT * FROM src") + } + + sql(s"ALTER TABLE $tableName SET LOCATION '$path'") + + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") + + assert(queryTotalSize(tableName) === BigInt(17436)) + } + } + } + + test("SPARK-21079 - analyze partitioned table with only a subset of partitions visible") { + def queryTotalSize(tableName: String): BigInt = + spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + + val sourceTableName = "analyzeTable_part" + val tableName = "analyzeTable_part_vis" + withTable(sourceTableName, tableName) { + withTempPath { path => + // Create a table with 3 partitions all located under a single top-level directory 'path' + sql( + s""" + |CREATE TABLE $sourceTableName (key STRING, value STRING) + |PARTITIONED BY (ds STRING) + |LOCATION '$path' + """.stripMargin) + + val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03") + partitionDates.foreach { ds => + sql( + s""" + |INSERT INTO TABLE $sourceTableName PARTITION (ds='$ds') + |SELECT * FROM src + """.stripMargin) + } + + // Create another table referring to the same location + sql( + s""" + |CREATE TABLE $tableName (key STRING, value STRING) + |PARTITIONED BY (ds STRING) + |LOCATION '$path' + """.stripMargin) + + // Register only one of the partitions found on disk + val ds = partitionDates.head + sql(s"ALTER TABLE $tableName ADD PARTITION (ds='$ds')").collect() + + // Analyze original table - expect 3 partitions + sql(s"ANALYZE TABLE $sourceTableName COMPUTE STATISTICS noscan") + assert(queryTotalSize(sourceTableName) === BigInt(3 * 5812)) + + // Analyze partial-copy table - expect only 1 partition + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") + assert(queryTotalSize(tableName) === BigInt(5812)) + } + } + } + test("analyzing views is not supported") { def assertAnalyzeUnsupported(analyzeCommand: String): Unit = { val err = intercept[AnalysisException] { @@ -145,23 +217,6 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } - private def checkTableStats( - tableName: String, - hasSizeInBytes: Boolean, - expectedRowCounts: Option[Int]): Option[CatalogStatistics] = { - val stats = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).stats - - if (hasSizeInBytes || expectedRowCounts.nonEmpty) { - assert(stats.isDefined) - assert(stats.get.sizeInBytes > 0) - assert(stats.get.rowCount === expectedRowCounts) - } else { - assert(stats.isEmpty) - } - - stats - } - test("test table-level statistics for hive tables created in HiveExternalCatalog") { val textTable = "textTable" withTable(textTable) { @@ -442,7 +497,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto test("estimates the size of a test Hive serde tables") { val df = sql("""SELECT * FROM src""") val sizes = df.queryExecution.analyzed.collect { - case relation: CatalogRelation => relation.stats(conf).sizeInBytes + case relation: HiveTableRelation => relation.stats(conf).sizeInBytes } assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}") assert(sizes(0).equals(BigInt(5812)), @@ -502,7 +557,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto () => (), metastoreQuery, metastoreAnswer, - implicitly[ClassTag[CatalogRelation]] + implicitly[ClassTag[HiveTableRelation]] ) } @@ -516,7 +571,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto // Assert src has a size smaller than the threshold. val sizes = df.queryExecution.analyzed.collect { - case relation: CatalogRelation => relation.stats(conf).sizeInBytes + case relation: HiveTableRelation => relation.stats(conf).sizeInBytes } assert(sizes.size === 2 && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 3906968aaff10..6b19b5ccb6f24 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -32,10 +32,11 @@ import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils // TODO(gatorsmile): combine HiveCatalogedDDLSuite and HiveDDLSuite class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeAndAfterEach { @@ -50,15 +51,28 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA protected override def generateTable( catalog: SessionCatalog, - name: TableIdentifier): CatalogTable = { + name: TableIdentifier, + isDataSource: Boolean): CatalogTable = { val storage = - CatalogStorageFormat( - locationUri = Some(catalog.defaultTablePath(name)), - inputFormat = Some("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"), - serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"), - compressed = false, - properties = Map("serialization.format" -> "1")) + if (isDataSource) { + val serde = HiveSerDe.sourceToSerDe("parquet") + assert(serde.isDefined, "The default format is not Hive compatible") + CatalogStorageFormat( + locationUri = Some(catalog.defaultTablePath(name)), + inputFormat = serde.get.inputFormat, + outputFormat = serde.get.outputFormat, + serde = serde.get.serde, + compressed = false, + properties = Map("serialization.format" -> "1")) + } else { + CatalogStorageFormat( + locationUri = Some(catalog.defaultTablePath(name)), + inputFormat = Some("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"), + serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"), + compressed = false, + properties = Map("serialization.format" -> "1")) + } val metadata = new MetadataBuilder() .putString("key", "value") .build() @@ -71,7 +85,7 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA .add("col2", "string") .add("a", "int") .add("b", "int"), - provider = Some("hive"), + provider = if (isDataSource) Some("parquet") else Some("hive"), partitionColumnNames = Seq("a", "b"), createTime = 0L, tracksPartitionsInCatalog = true) @@ -107,6 +121,46 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA ) } + test("alter table: set location") { + testSetLocation(isDatasourceTable = false) + } + + test("alter table: set properties") { + testSetProperties(isDatasourceTable = false) + } + + test("alter table: unset properties") { + testUnsetProperties(isDatasourceTable = false) + } + + test("alter table: set serde") { + testSetSerde(isDatasourceTable = false) + } + + test("alter table: set serde partition") { + testSetSerdePartition(isDatasourceTable = false) + } + + test("alter table: change column") { + testChangeColumn(isDatasourceTable = false) + } + + test("alter table: rename partition") { + testRenamePartitions(isDatasourceTable = false) + } + + test("alter table: drop partition") { + testDropPartitions(isDatasourceTable = false) + } + + test("alter table: add partition") { + testAddPartitions(isDatasourceTable = false) + } + + test("drop table") { + testDropTable(isDatasourceTable = false) + } + } class HiveDDLSuite @@ -130,7 +184,7 @@ class HiveDDLSuite if (dbPath.isEmpty) { hiveContext.sessionState.catalog.defaultTablePath(tableIdentifier) } else { - new Path(new Path(dbPath.get), tableIdentifier.table) + new Path(new Path(dbPath.get), tableIdentifier.table).toUri } val filesystemPath = new Path(expectedTablePath.toString) val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf()) @@ -732,7 +786,7 @@ class HiveDDLSuite checkAnswer( sql(s"DESC $tabName").select("col_name", "data_type", "comment"), - Row("# col_name", "data_type", "comment") :: Row("a", "int", "test") :: Nil + Row("a", "int", "test") :: Nil ) } } @@ -1197,6 +1251,14 @@ class HiveDDLSuite s"CREATE INDEX $indexName ON TABLE $tabName (a) AS 'COMPACT' WITH DEFERRED REBUILD") val indexTabName = spark.sessionState.catalog.listTables("default", s"*$indexName*").head.table + + // Even if index tables exist, listTables and getTable APIs should still work + checkAnswer( + spark.catalog.listTables().toDF(), + Row(indexTabName, "default", null, null, false) :: + Row(tabName, "default", null, "MANAGED", false) :: Nil) + assert(spark.catalog.getTable("default", indexTabName).name === indexTabName) + intercept[TableAlreadyExistsException] { sql(s"CREATE TABLE $indexTabName(b int)") } @@ -1575,7 +1637,7 @@ class HiveDDLSuite test("create hive table with a non-existing location") { withTable("t", "t1") { withTempPath { dir => - spark.sql(s"CREATE TABLE t(a int, b int) USING hive LOCATION '$dir'") + spark.sql(s"CREATE TABLE t(a int, b int) USING hive LOCATION '${dir.toURI}'") val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.location == makeQualifiedPath(dir.getAbsolutePath)) @@ -1592,7 +1654,7 @@ class HiveDDLSuite |CREATE TABLE t1(a int, b int) |USING hive |PARTITIONED BY(a) - |LOCATION '$dir' + |LOCATION '${dir.toURI}' """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) @@ -1620,7 +1682,7 @@ class HiveDDLSuite s""" |CREATE TABLE t |USING hive - |LOCATION '$dir' + |LOCATION '${dir.toURI}' |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) @@ -1636,7 +1698,7 @@ class HiveDDLSuite |CREATE TABLE t1 |USING hive |PARTITIONED BY(a, b) - |LOCATION '$dir' + |LOCATION '${dir.toURI}' |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) @@ -1662,21 +1724,21 @@ class HiveDDLSuite |CREATE TABLE t(a string, `$specialChars` string) |USING $datasource |PARTITIONED BY(`$specialChars`) - |LOCATION '$dir' + |LOCATION '${dir.toURI}' """.stripMargin) assert(dir.listFiles().isEmpty) spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1") val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2" val partFile = new File(dir, partEscaped) - assert(partFile.listFiles().length >= 1) + assert(partFile.listFiles().nonEmpty) checkAnswer(spark.table("t"), Row("1", "2") :: Nil) withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`) SELECT 3, 4") val partEscaped1 = s"${ExternalCatalogUtils.escapePathName(specialChars)}=4" val partFile1 = new File(dir, partEscaped1) - assert(partFile1.listFiles().length >= 1) + assert(partFile1.listFiles().nonEmpty) checkAnswer(spark.table("t"), Row("1", "2") :: Row("3", "4") :: Nil) } } @@ -1687,15 +1749,22 @@ class HiveDDLSuite Seq("a b", "a:b", "a%b").foreach { specialChars => test(s"hive table: location uri contains $specialChars") { + // On Windows, it looks colon in the file name is illegal by default. See + // https://support.microsoft.com/en-us/help/289627 + assume(!Utils.isWindows || specialChars != "a:b") + withTable("t") { withTempDir { dir => val loc = new File(dir, specialChars) loc.mkdir() + // The parser does not recognize the backslashes on Windows as they are. + // These currently should be escaped. + val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\") spark.sql( s""" |CREATE TABLE t(a string) |USING hive - |LOCATION '$loc' + |LOCATION '$escapedLoc' """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) @@ -1718,12 +1787,13 @@ class HiveDDLSuite withTempDir { dir => val loc = new File(dir, specialChars) loc.mkdir() + val escapedLoc = loc.getAbsolutePath.replace("\\", "\\\\") spark.sql( s""" |CREATE TABLE t1(a string, b string) |USING hive |PARTITIONED BY(b) - |LOCATION '$loc' + |LOCATION '$escapedLoc' """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) @@ -1734,16 +1804,20 @@ class HiveDDLSuite if (specialChars != "a:b") { spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") val partFile = new File(loc, "b=2") - assert(partFile.listFiles().length >= 1) + assert(partFile.listFiles().nonEmpty) checkAnswer(spark.table("t1"), Row("1", "2") :: Nil) spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14") assert(!partFile1.exists()) - val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") - assert(partFile2.listFiles().length >= 1) - checkAnswer(spark.table("t1"), - Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + + if (!Utils.isWindows) { + // Actual path becomes "b=2017-03-03%2012%3A13%253A14" on Windows. + val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") + assert(partFile2.listFiles().nonEmpty) + checkAnswer(spark.table("t1"), + Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + } } else { val e = intercept[AnalysisException] { spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 8a37bc3665d32..aa1ca2909074f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -43,11 +43,29 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto test("explain extended command") { checkKeywordsExist(sql(" explain select * from src where key=123 "), - "== Physical Plan ==") + "== Physical Plan ==", + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") + checkKeywordsNotExist(sql(" explain select * from src where key=123 "), "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", - "== Optimized Logical Plan ==") + "== Optimized Logical Plan ==", + "Owner", + "Database", + "Created", + "Last Access", + "Type", + "Provider", + "Properties", + "Statistics", + "Location", + "Serde Library", + "InputFormat", + "OutputFormat", + "Partition Provider", + "Schema" + ) + checkKeywordsExist(sql(" explain extended select * from src where key=123 "), "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 90e037e292790..ae64cb3210b53 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -164,16 +164,30 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH |PARTITION (p1='a',p2='c',p3='c',p4='d',p5='e') |SELECT v.id """.stripMargin) - val plan = sql( - s""" - |SELECT * FROM $table - """.stripMargin).queryExecution.sparkPlan - val scan = plan.collectFirst { - case p: HiveTableScanExec => p - }.get + val scan = getHiveTableScanExec(s"SELECT * FROM $table") val numDataCols = scan.relation.dataCols.length scan.rawPartitions.foreach(p => assert(p.getCols.size == numDataCols)) } } } + + test("HiveTableScanExec canonicalization for different orders of partition filters") { + val table = "hive_tbl_part" + withTable(table) { + sql( + s""" + |CREATE TABLE $table (id int) + |PARTITIONED BY (a int, b int) + """.stripMargin) + val scan1 = getHiveTableScanExec(s"SELECT * FROM $table WHERE a = 1 AND b = 2") + val scan2 = getHiveTableScanExec(s"SELECT * FROM $table WHERE b = 2 AND a = 1") + assert(scan1.sameResult(scan2)) + } + } + + private def getHiveTableScanExec(query: String): HiveTableScanExec = { + sql(query).queryExecution.sparkPlan.collectFirst { + case p: HiveTableScanExec => p + }.get + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 4446af2e75e00..8fcbad58350f4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.functions.max import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils @@ -590,6 +591,25 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } } } + + test("Call the function registered in the not-current database") { + Seq("true", "false").foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) { + withDatabase("dAtABaSe1") { + sql("CREATE DATABASE dAtABaSe1") + withUserDefinedFunction("dAtABaSe1.test_avg" -> false) { + sql(s"CREATE FUNCTION dAtABaSe1.test_avg AS '${classOf[GenericUDAFAverage].getName}'") + checkAnswer(sql("SELECT dAtABaSe1.test_avg(1)"), Row(1.0)) + } + val message = intercept[AnalysisException] { + sql("SELECT dAtABaSe1.unknownFunc(1)") + }.getMessage + assert(message.contains("Undefined function: 'unknownFunc'") && + message.contains("nor a permanent function registered in the database 'dAtABaSe1'")) + } + } + } + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala new file mode 100644 index 0000000000000..5c248b9acd04f --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import scala.language.existentials + +import org.apache.hadoop.conf.Configuration +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.launcher.SparkLauncher +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.StaticSQLConf._ +import org.apache.spark.sql.types._ +import org.apache.spark.tags.ExtendedHiveTest +import org.apache.spark.util.Utils + +/** + * A separate set of DDL tests that uses Hive 2.1 libraries, which behave a little differently + * from the built-in ones. + */ +@ExtendedHiveTest +class Hive_2_1_DDLSuite extends SparkFunSuite with TestHiveSingleton with BeforeAndAfterEach + with BeforeAndAfterAll { + + // Create a custom HiveExternalCatalog instance with the desired configuration. We cannot + // use SparkSession here since there's already an active on managed by the TestHive object. + private var catalog = { + val warehouse = Utils.createTempDir() + val metastore = Utils.createTempDir() + metastore.delete() + val sparkConf = new SparkConf() + .set(SparkLauncher.SPARK_MASTER, "local") + .set(WAREHOUSE_PATH.key, warehouse.toURI().toString()) + .set(CATALOG_IMPLEMENTATION.key, "hive") + .set(HiveUtils.HIVE_METASTORE_VERSION.key, "2.1") + .set(HiveUtils.HIVE_METASTORE_JARS.key, "maven") + + val hadoopConf = new Configuration() + hadoopConf.set("hive.metastore.warehouse.dir", warehouse.toURI().toString()) + hadoopConf.set("javax.jdo.option.ConnectionURL", + s"jdbc:derby:;databaseName=${metastore.getAbsolutePath()};create=true") + // These options are needed since the defaults in Hive 2.1 cause exceptions with an + // empty metastore db. + hadoopConf.set("datanucleus.schema.autoCreateAll", "true") + hadoopConf.set("hive.metastore.schema.verification", "false") + + new HiveExternalCatalog(sparkConf, hadoopConf) + } + + override def afterEach: Unit = { + catalog.listTables("default").foreach { t => + catalog.dropTable("default", t, true, false) + } + spark.sessionState.catalog.reset() + } + + override def afterAll(): Unit = { + catalog = null + } + + test("SPARK-21617: ALTER TABLE for non-compatible DataSource tables") { + testAlterTable( + "t1", + "CREATE TABLE t1 (c1 int) USING json", + StructType(Array(StructField("c1", IntegerType), StructField("c2", IntegerType))), + hiveCompatible = false) + } + + test("SPARK-21617: ALTER TABLE for Hive-compatible DataSource tables") { + testAlterTable( + "t1", + "CREATE TABLE t1 (c1 int) USING parquet", + StructType(Array(StructField("c1", IntegerType), StructField("c2", IntegerType)))) + } + + test("SPARK-21617: ALTER TABLE for Hive tables") { + testAlterTable( + "t1", + "CREATE TABLE t1 (c1 int) STORED AS parquet", + StructType(Array(StructField("c1", IntegerType), StructField("c2", IntegerType)))) + } + + test("SPARK-21617: ALTER TABLE with incompatible schema on Hive-compatible table") { + val exception = intercept[AnalysisException] { + testAlterTable( + "t1", + "CREATE TABLE t1 (c1 string) USING parquet", + StructType(Array(StructField("c2", IntegerType)))) + } + assert(exception.getMessage().contains("types incompatible with the existing columns")) + } + + private def testAlterTable( + tableName: String, + createTableStmt: String, + updatedSchema: StructType, + hiveCompatible: Boolean = true): Unit = { + spark.sql(createTableStmt) + val oldTable = spark.sessionState.catalog.externalCatalog.getTable("default", tableName) + catalog.createTable(oldTable, true) + catalog.alterTableSchema("default", tableName, updatedSchema) + + val updatedTable = catalog.getTable("default", tableName) + assert(updatedTable.schema.fieldNames === updatedSchema.fieldNames) + } + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index f818e29555468..d91f25a4da013 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} @@ -66,4 +67,28 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te } } } + + test("SPARK-20986 Reset table's statistics after PruneFileSourcePartitions rule") { + withTable("tbl") { + spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("tbl") + sql(s"ANALYZE TABLE tbl COMPUTE STATISTICS") + val tableStats = spark.sessionState.catalog.getTableMetadata(TableIdentifier("tbl")).stats + assert(tableStats.isDefined && tableStats.get.sizeInBytes > 0, "tableStats is lost") + + val df = sql("SELECT * FROM tbl WHERE p = 1") + val sizes1 = df.queryExecution.analyzed.collect { + case relation: LogicalRelation => relation.catalogTable.get.stats.get.sizeInBytes + } + assert(sizes1.size === 1, s"Size wrong for:\n ${df.queryExecution}") + assert(sizes1(0) == tableStats.get.sizeInBytes) + + val relations = df.queryExecution.optimizedPlan.collect { + case relation: LogicalRelation => relation + } + assert(relations.size === 1, s"Size wrong for:\n ${df.queryExecution}") + val size2 = relations(0).computeStats(conf).sizeInBytes + assert(size2 == relations(0).catalogTable.get.stats.get.sizeInBytes) + assert(size2 < tableStats.get.sizeInBytes) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 75f3744ff35be..31b36f1574d16 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -20,16 +20,16 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} -import java.util.Locale +import java.util.{Locale, Set} import com.google.common.io.Files -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.TestUtils import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry, NoSuchPartitionException} -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTableType, CatalogUtils} +import org.apache.spark.sql.catalyst.catalog.{CatalogTableType, CatalogUtils, HiveTableRelation} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} @@ -454,7 +454,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { case LogicalRelation(r: HadoopFsRelation, _, _) => if (!isDataSourceTable) { fail( - s"${classOf[CatalogRelation].getCanonicalName} is expected, but found " + + s"${classOf[HiveTableRelation].getCanonicalName} is expected, but found " + s"${HadoopFsRelation.getClass.getCanonicalName}.") } userSpecifiedLocation match { @@ -464,11 +464,11 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } assert(catalogTable.provider.get === format) - case r: CatalogRelation => + case r: HiveTableRelation => if (isDataSourceTable) { fail( s"${HadoopFsRelation.getClass.getCanonicalName} is expected, but found " + - s"${classOf[CatalogRelation].getCanonicalName}.") + s"${classOf[HiveTableRelation].getCanonicalName}.") } userSpecifiedLocation match { case Some(location) => @@ -948,7 +948,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { withSQLConf(SQLConf.CONVERT_CTAS.key -> "false") { sql("CREATE TABLE explodeTest (key bigInt)") table("explodeTest").queryExecution.analyzed match { - case SubqueryAlias(_, r: CatalogRelation) => // OK + case SubqueryAlias(_, r: HiveTableRelation) => // OK case _ => fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") } @@ -1976,6 +1976,30 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("Auto alias construction of get_json_object") { + val df = Seq(("1", """{"f1": "value1", "f5": 5.23}""")).toDF("key", "jstring") + val expectedMsg = "Cannot create a table having a column whose name contains commas " + + "in Hive metastore. Table: `default`.`t`; Column: get_json_object(jstring, $.f1)" + + withTable("t") { + val e = intercept[AnalysisException] { + df.select($"key", functions.get_json_object($"jstring", "$.f1")) + .write.format("hive").saveAsTable("t") + }.getMessage + assert(e.contains(expectedMsg)) + } + + withTempView("tempView") { + withTable("t") { + df.createTempView("tempView") + val e = intercept[AnalysisException] { + sql("CREATE TABLE t AS SELECT key, get_json_object(jstring, '$.f1') FROM tempView") + }.getMessage + assert(e.contains(expectedMsg)) + } + } + } + test("SPARK-19912 String literals should be escaped for Hive metastore partition pruning") { withTable("spark_19912") { Seq( @@ -1991,4 +2015,23 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(table.filter($"p" === "p1\" and q=\"q1").select($"a"), Row(4)) } } + + test("SPARK-21721: Clear FileSystem deleterOnExit cache if path is successfully removed") { + val table = "test21721" + withTable(table) { + val deleteOnExitField = classOf[FileSystem].getDeclaredField("deleteOnExit") + deleteOnExitField.setAccessible(true) + + val fs = FileSystem.get(spark.sparkContext.hadoopConfiguration) + val setOfPath = deleteOnExitField.get(fs).asInstanceOf[Set[Path]] + + val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() + sql(s"CREATE TABLE $table (key INT, value STRING)") + val pathSizeToDeleteOnExit = setOfPath.size() + + (0 to 10).foreach(_ => testData.write.mode(SaveMode.Append).insertInto(table)) + + assert(setOfPath.size() == pathSizeToDeleteOnExit) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala index a20c758a83e71..3f9485dd018b1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala @@ -232,31 +232,4 @@ class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleto Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 46, 99807.08486666666, -0.9978877469246935, -5664.856666666666))) // scalastyle:on } - - test("null arguments") { - checkAnswer(sql(""" - |select p_mfgr, p_name, p_size, - |sum(null) over(distribute by p_mfgr sort by p_name) as sum, - |avg(null) over(distribute by p_mfgr sort by p_name) as avg - |from part - """.stripMargin), - sql(""" - |select p_mfgr, p_name, p_size, - |null as sum, - |null as avg - |from part - """.stripMargin)) - } - - test("SPARK-16646: LAST_VALUE(FALSE) OVER ()") { - checkAnswer(sql("SELECT LAST_VALUE(FALSE) OVER ()"), Row(false)) - checkAnswer(sql("SELECT LAST_VALUE(FALSE, FALSE) OVER ()"), Row(false)) - checkAnswer(sql("SELECT LAST_VALUE(TRUE, TRUE) OVER ()"), Row(true)) - } - - test("SPARK-16646: FIRST_VALUE(FALSE) OVER ()") { - checkAnswer(sql("SELECT FIRST_VALUE(FALSE) OVER ()"), Row(false)) - checkAnswer(sql("SELECT FIRST_VALUE(FALSE, FALSE) OVER ()"), Row(false)) - checkAnswer(sql("SELECT FIRST_VALUE(TRUE, TRUE) OVER ()"), Row(true)) - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 8c855730c31f2..60ccd996d6d58 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -26,7 +26,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogRelation +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, RecordReaderIterator} import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHive._ @@ -475,7 +475,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } else { queryExecution.analyzed.collectFirst { - case _: CatalogRelation => () + case _: HiveTableRelation => () }.getOrElse { fail(s"Expecting no conversion from orc to data sources, " + s"but got:\n$queryExecution") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 23f21e6b9931e..303884da19f09 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -21,7 +21,7 @@ import java.io.File import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogRelation +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ @@ -812,7 +812,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { } } else { queryExecution.analyzed.collectFirst { - case _: CatalogRelation => + case _: HiveTableRelation => }.getOrElse { fail(s"Expecting no conversion from parquet to data sources, " + s"but got:\n$queryExecution") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 9f4009bfe402a..60a4638f610b3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -103,7 +103,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { // `Cast`ed values are always of internal types (e.g. UTF8String instead of String) Cast(Literal(value), dataType).eval() }) - }.filter(predicate).map(projection) + }.filter(predicate.eval).map(projection) // Appends partition values val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes diff --git a/streaming/pom.xml b/streaming/pom.xml index de1be9c13e05f..681799a734a1e 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.0-csd-1-SNAPSHOT ../pom.xml diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 5cbad8bf3ce6e..b8c780db07c98 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -55,6 +55,9 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.master", "spark.yarn.keytab", "spark.yarn.principal", + "spark.yarn.credentials.file", + "spark.yarn.credentials.renewalTime", + "spark.yarn.credentials.updateTime", "spark.ui.filters") val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index f55af6a5cc358..69e15655ad790 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -304,7 +304,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } def render(request: HttpServletRequest): Seq[Node] = streamingListener.synchronized { - val batchTime = Option(request.getParameter("id")).map(id => Time(id.toLong)).getOrElse { + // stripXSS is called first to remove suspicious characters used in XSS attacks + val batchTime = + Option(SparkUIUtils.stripXSS(request.getParameter("id"))).map(id => Time(id.toLong)) + .getOrElse { throw new IllegalArgumentException(s"Missing id parameter") } val formattedBatchTime = diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 845f554308c43..1e5f18797e152 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -189,7 +189,9 @@ private[streaming] class FileBasedWriteAheadLog( val f = Future { deleteFile(logInfo) }(executionContext) if (waitForCompletion) { import scala.concurrent.duration._ + // scalastyle:off awaitready Await.ready(f, 1 second) + // scalastyle:on awaitready } } catch { case e: RejectedExecutionException => diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala index b70383ecde4d8..4f41b9d0a0b3c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -21,7 +21,6 @@ import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.language.reflectiveCalls import org.scalatest.BeforeAndAfter import org.scalatest.Matchers._ @@ -202,21 +201,17 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { test("block push errors are reported") { val listener = new TestBlockGeneratorListener { - @volatile var errorReported = false override def onPushBlock( blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { throw new SparkException("test") } - override def onError(message: String, throwable: Throwable): Unit = { - errorReported = true - } } blockGenerator = new BlockGenerator(listener, 0, conf) blockGenerator.start() - assert(listener.errorReported === false) + assert(listener.onErrorCalled === false) blockGenerator.addData(1) eventually(timeout(1 second), interval(10 milliseconds)) { - assert(listener.errorReported === true) + assert(listener.onErrorCalled === true) } blockGenerator.stop() } @@ -243,12 +238,15 @@ class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { @volatile var onGenerateBlockCalled = false @volatile var onAddDataCalled = false @volatile var onPushBlockCalled = false + @volatile var onErrorCalled = false override def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { pushedData.addAll(arrayBuffer.asJava) onPushBlockCalled = true } - override def onError(message: String, throwable: Throwable): Unit = {} + override def onError(message: String, throwable: Throwable): Unit = { + onErrorCalled = true + } override def onGenerateBlock(blockId: StreamBlockId): Unit = { onGenerateBlockCalled = true } diff --git a/tools/pom.xml b/tools/pom.xml index 938ba2f6ac201..c361b7fed8c8c 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.0-csd-1-SNAPSHOT ../pom.xml