diff --git a/.gitignore b/.gitignore index debad77ec2ad3..08f2d8f7543f0 100644 --- a/.gitignore +++ b/.gitignore @@ -74,3 +74,7 @@ metastore/ warehouse/ TempStatsStore/ sql/hive-thriftserver/test_warehouses + +# For R session data +.RHistory +.RData diff --git a/R/install-dev.bat b/R/install-dev.bat index 008a5c668bc45..ed1c91ae3a0ff 100644 --- a/R/install-dev.bat +++ b/R/install-dev.bat @@ -25,3 +25,9 @@ set SPARK_HOME=%~dp0.. MKDIR %SPARK_HOME%\R\lib R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\ + +rem Zip the SparkR package so that it can be distributed to worker nodes on YARN +pushd %SPARK_HOME%\R\lib +%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR +popd + diff --git a/R/install-dev.sh b/R/install-dev.sh index 59d98c9c7a646..4972bb9217072 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -42,4 +42,8 @@ Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtoo # Install SparkR to $LIB_DIR R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ +# Zip the SparkR package so that it can be distributed to worker nodes on YARN +cd $LIB_DIR +jar cfM "$LIB_DIR/sparkr.zip" SparkR + popd > /dev/null diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 3d6edb70ec98e..369714f7b99c2 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -34,4 +34,5 @@ Collate: 'serialize.R' 'sparkR.R' 'stats.R' + 'types.R' 'utils.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 56b8ed0bf271b..2ee7d6f94f1bc 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -23,9 +23,11 @@ export("setJobGroup", exportClasses("DataFrame") exportMethods("arrange", + "as.data.frame", "attach", "cache", "collect", + "coltypes", "columns", "count", "cov", @@ -153,6 +155,7 @@ exportMethods("%in%", "isNaN", "isNotNull", "isNull", + "kurtosis", "lag", "last", "last_day", @@ -205,12 +208,17 @@ exportMethods("%in%", "shiftLeft", "shiftRight", "shiftRightUnsigned", + "sd", "sign", "signum", "sin", "sinh", "size", + "skewness", "soundex", + "stddev", + "stddev_pop", + "stddev_samp", "sqrt", "startsWith", "substr", @@ -229,6 +237,10 @@ exportMethods("%in%", "unhex", "unix_timestamp", "upper", + "var", + "variance", + "var_pop", + "var_samp", "weekofyear", "when", "year") @@ -262,6 +274,4 @@ export("structField", "structType", "structType.jobj", "structType.structField", - "print.structType") - -export("as.data.frame") + "print.structType") \ No newline at end of file diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 44ce9414da5cf..fd105ba5bc9bb 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -25,7 +25,7 @@ setOldClass("jobj") #' @title S4 class that represents a DataFrame #' @description DataFrames can be created using functions like \link{createDataFrame}, #' \link{jsonFile}, \link{table} etc. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname DataFrame #' @docType class #' @@ -68,7 +68,7 @@ dataFrame <- function(sdf, isCached = FALSE) { #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname printSchema #' @name printSchema #' @export @@ -93,7 +93,7 @@ setMethod("printSchema", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname schema #' @name schema #' @export @@ -117,7 +117,7 @@ setMethod("schema", #' #' @param x A SparkSQL DataFrame #' @param extended Logical. If extended is False, explain() only prints the physical plan. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname explain #' @name explain #' @export @@ -148,7 +148,7 @@ setMethod("explain", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname isLocal #' @name isLocal #' @export @@ -173,7 +173,7 @@ setMethod("isLocal", #' @param x A SparkSQL DataFrame #' @param numRows The number of rows to print. Defaults to 20. #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname showDF #' @name showDF #' @export @@ -198,7 +198,7 @@ setMethod("showDF", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname show #' @name show #' @export @@ -225,7 +225,7 @@ setMethod("show", "DataFrame", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname dtypes #' @name dtypes #' @export @@ -251,7 +251,7 @@ setMethod("dtypes", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname columns #' @name columns #' @aliases names @@ -272,7 +272,7 @@ setMethod("columns", }) }) -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname columns #' @name names setMethod("names", @@ -281,7 +281,7 @@ setMethod("names", columns(x) }) -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname columns #' @name names<- setMethod("names<-", @@ -300,7 +300,7 @@ setMethod("names<-", #' @param x A SparkSQL DataFrame #' @param tableName A character vector containing the name of the table #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname registerTempTable #' @name registerTempTable #' @export @@ -328,7 +328,7 @@ setMethod("registerTempTable", #' @param overwrite A logical argument indicating whether or not to overwrite #' the existing rows in the table. #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname insertInto #' @name insertInto #' @export @@ -353,7 +353,7 @@ setMethod("insertInto", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname cache #' @name cache #' @export @@ -381,7 +381,7 @@ setMethod("cache", #' #' @param x The DataFrame to persist #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname persist #' @name persist #' @export @@ -409,7 +409,7 @@ setMethod("persist", #' @param x The DataFrame to unpersist #' @param blocking Whether to block until all blocks are deleted #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname unpersist-methods #' @name unpersist #' @export @@ -437,7 +437,7 @@ setMethod("unpersist", #' @param x A SparkSQL DataFrame #' @param numPartitions The number of partitions to use. #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname repartition #' @name repartition #' @export @@ -456,25 +456,24 @@ setMethod("repartition", dataFrame(sdf) }) -# toJSON -# -# Convert the rows of a DataFrame into JSON objects and return an RDD where -# each element contains a JSON string. -# -# @param x A SparkSQL DataFrame -# @return A StringRRDD of JSON objects -# -# @family dataframe_funcs -# @rdname tojson -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# sqlContext <- sparkRSQL.init(sc) -# path <- "path/to/file.json" -# df <- jsonFile(sqlContext, path) -# newRDD <- toJSON(df) -#} +#' toJSON +#' +#' Convert the rows of a DataFrame into JSON objects and return an RDD where +#' each element contains a JSON string. +#' +#' @param x A SparkSQL DataFrame +#' @return A StringRRDD of JSON objects +#' @family DataFrame functions +#' @rdname tojson +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' newRDD <- toJSON(df) +#'} setMethod("toJSON", signature(x = "DataFrame"), function(x) { @@ -491,7 +490,7 @@ setMethod("toJSON", #' @param x A SparkSQL DataFrame #' @param path The directory where the file is saved #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname saveAsParquetFile #' @name saveAsParquetFile #' @export @@ -515,7 +514,7 @@ setMethod("saveAsParquetFile", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname distinct #' @name distinct #' @export @@ -538,7 +537,7 @@ setMethod("distinct", # #' @description Returns a new DataFrame containing distinct rows in this DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname unique #' @name unique #' @aliases distinct @@ -556,7 +555,7 @@ setMethod("unique", #' @param withReplacement Sampling with replacement or not #' @param fraction The (rough) sample target fraction #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname sample #' @aliases sample_frac #' @export @@ -580,7 +579,7 @@ setMethod("sample", dataFrame(sdf) }) -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname sample #' @name sample_frac setMethod("sample_frac", @@ -596,7 +595,7 @@ setMethod("sample_frac", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname count #' @name count #' @aliases nrow @@ -620,7 +619,7 @@ setMethod("count", #' #' @name nrow #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname nrow #' @aliases count setMethod("nrow", @@ -633,7 +632,7 @@ setMethod("nrow", #' #' @param x a SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname ncol #' @name ncol #' @export @@ -654,7 +653,7 @@ setMethod("ncol", #' Returns the dimentions (number of rows and columns) of a DataFrame #' @param x a SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname dim #' @name dim #' @export @@ -678,7 +677,7 @@ setMethod("dim", #' @param stringsAsFactors (Optional) A logical indicating whether or not string columns #' should be converted to factors. FALSE by default. #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname collect #' @name collect #' @export @@ -746,7 +745,7 @@ setMethod("collect", #' @param num The number of rows to return #' @return A new DataFrame containing the number of rows specified. #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname limit #' @name limit #' @export @@ -767,7 +766,7 @@ setMethod("limit", #' Take the first NUM rows of a DataFrame and return a the results as a data.frame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname take #' @name take #' @export @@ -796,7 +795,7 @@ setMethod("take", #' @param num The number of rows to return. Default is 6. #' @return A data.frame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname head #' @name head #' @export @@ -819,7 +818,7 @@ setMethod("head", #' #' @param x A SparkSQL DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname first #' @name first #' @export @@ -837,23 +836,21 @@ setMethod("first", take(x, 1) }) -# toRDD -# -# Converts a Spark DataFrame to an RDD while preserving column names. -# -# @param x A Spark DataFrame -# -# @family dataframe_funcs -# @rdname DataFrame -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# sqlContext <- sparkRSQL.init(sc) -# path <- "path/to/file.json" -# df <- jsonFile(sqlContext, path) -# rdd <- toRDD(df) -# } +#' toRDD +#' +#' Converts a Spark DataFrame to an RDD while preserving column names. +#' +#' @param x A Spark DataFrame +#' +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlContext, path) +#' rdd <- toRDD(df) +#'} setMethod("toRDD", signature(x = "DataFrame"), function(x) { @@ -874,7 +871,7 @@ setMethod("toRDD", #' @return a GroupedData #' @seealso GroupedData #' @aliases group_by -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname groupBy #' @name groupBy #' @export @@ -899,7 +896,7 @@ setMethod("groupBy", groupedData(sgd) }) -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname groupBy #' @name group_by setMethod("group_by", @@ -913,7 +910,7 @@ setMethod("group_by", #' Compute aggregates by specifying a list of columns #' #' @param x a DataFrame -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname agg #' @name agg #' @aliases summarize @@ -924,7 +921,7 @@ setMethod("agg", agg(groupBy(x), ...) }) -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname agg #' @name summarize setMethod("summarize", @@ -940,8 +937,8 @@ setMethod("summarize", # the requested map function. # ################################################################################### -# @family dataframe_funcs -# @rdname lapply +#' @rdname lapply +#' @noRd setMethod("lapply", signature(X = "DataFrame", FUN = "function"), function(X, FUN) { @@ -949,24 +946,25 @@ setMethod("lapply", lapply(rdd, FUN) }) -# @family dataframe_funcs -# @rdname lapply +#' @rdname lapply +#' @noRd setMethod("map", signature(X = "DataFrame", FUN = "function"), function(X, FUN) { lapply(X, FUN) }) -# @family dataframe_funcs -# @rdname flatMap +#' @rdname flatMap +#' @noRd setMethod("flatMap", signature(X = "DataFrame", FUN = "function"), function(X, FUN) { rdd <- toRDD(X) flatMap(rdd, FUN) }) -# @family dataframe_funcs -# @rdname lapplyPartition + +#' @rdname lapplyPartition +#' @noRd setMethod("lapplyPartition", signature(X = "DataFrame", FUN = "function"), function(X, FUN) { @@ -974,16 +972,16 @@ setMethod("lapplyPartition", lapplyPartition(rdd, FUN) }) -# @family dataframe_funcs -# @rdname lapplyPartition +#' @rdname lapplyPartition +#' @noRd setMethod("mapPartitions", signature(X = "DataFrame", FUN = "function"), function(X, FUN) { lapplyPartition(X, FUN) }) -# @family dataframe_funcs -# @rdname foreach +#' @rdname foreach +#' @noRd setMethod("foreach", signature(x = "DataFrame", func = "function"), function(x, func) { @@ -991,8 +989,8 @@ setMethod("foreach", foreach(rdd, func) }) -# @family dataframe_funcs -# @rdname foreach +#' @rdname foreach +#' @noRd setMethod("foreachPartition", signature(x = "DataFrame", func = "function"), function(x, func) { @@ -1091,7 +1089,7 @@ setMethod("[", signature(x = "DataFrame", i = "Column"), #' @param select expression for the single Column or a list of columns to select from the DataFrame #' @return A new DataFrame containing only the rows that meet the condition with selected columns #' @export -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname subset #' @name subset #' @aliases [ @@ -1122,7 +1120,7 @@ setMethod("subset", signature(x = "DataFrame"), #' @param col A list of columns or single Column or name #' @return A new DataFrame with selected columns #' @export -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname select #' @name select #' @family subsetting functions @@ -1150,7 +1148,7 @@ setMethod("select", signature(x = "DataFrame", col = "character"), } }) -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname select #' @export setMethod("select", signature(x = "DataFrame", col = "Column"), @@ -1162,7 +1160,7 @@ setMethod("select", signature(x = "DataFrame", col = "Column"), dataFrame(sdf) }) -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname select #' @export setMethod("select", @@ -1187,7 +1185,7 @@ setMethod("select", #' @param expr A string containing a SQL expression #' @param ... Additional expressions #' @return A DataFrame -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname selectExpr #' @name selectExpr #' @export @@ -1215,7 +1213,7 @@ setMethod("selectExpr", #' @param colName A string containing the name of the new column. #' @param col A Column expression. #' @return A DataFrame with the new column added. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname withColumn #' @name withColumn #' @aliases mutate transform @@ -1241,7 +1239,7 @@ setMethod("withColumn", #' @param .data A DataFrame #' @param col a named argument of the form name = col #' @return A new DataFrame with the new columns added. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname withColumn #' @name mutate #' @aliases withColumn transform @@ -1275,7 +1273,7 @@ setMethod("mutate", }) #' @export -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname withColumn #' @name transform #' @aliases withColumn mutate @@ -1293,7 +1291,7 @@ setMethod("transform", #' @param existingCol The name of the column you want to change. #' @param newCol The new column name. #' @return A DataFrame with the column name changed. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname withColumnRenamed #' @name withColumnRenamed #' @export @@ -1325,7 +1323,7 @@ setMethod("withColumnRenamed", #' @param x A DataFrame #' @param newCol A named pair of the form new_column_name = existing_column #' @return A DataFrame with the column name changed. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname withColumnRenamed #' @name rename #' @aliases withColumnRenamed @@ -1370,7 +1368,7 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @param decreasing A logical argument indicating sorting order for columns when #' a character vector is specified for col #' @return A DataFrame where all elements are sorted. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname arrange #' @name arrange #' @aliases orderby @@ -1397,7 +1395,7 @@ setMethod("arrange", dataFrame(sdf) }) -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname arrange #' @export setMethod("arrange", @@ -1429,7 +1427,7 @@ setMethod("arrange", do.call("arrange", c(x, jcols)) }) -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname arrange #' @name orderby setMethod("orderBy", @@ -1446,7 +1444,7 @@ setMethod("orderBy", #' @param condition The condition to filter on. This may either be a Column expression #' or a string containing a SQL statement #' @return A DataFrame containing only the rows that meet the condition. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname filter #' @name filter #' @family subsetting functions @@ -1470,7 +1468,7 @@ setMethod("filter", dataFrame(sdf) }) -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname filter #' @name where setMethod("where", @@ -1491,7 +1489,7 @@ setMethod("where", #' 'inner', 'outer', 'full', 'fullouter', leftouter', 'left_outer', 'left', #' 'right_outer', 'rightouter', 'right', and 'leftsemi'. The default joinType is "inner". #' @return A DataFrame containing the result of the join operation. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname join #' @name join #' @export @@ -1550,7 +1548,7 @@ setMethod("join", #' be returned. If all.x is set to FALSE and all.y is set to TRUE, a right #' outer join will be returned. If all.x and all.y are set to TRUE, a full #' outer join will be returned. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname merge #' @export #' @examples @@ -1682,7 +1680,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' @param x A Spark DataFrame #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the union. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname unionAll #' @name unionAll #' @export @@ -1705,7 +1703,7 @@ setMethod("unionAll", #' #' @description Returns a new DataFrame containing rows of all parameters. #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname rbind #' @name rbind #' @aliases unionAll @@ -1727,7 +1725,7 @@ setMethod("rbind", #' @param x A Spark DataFrame #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the intersect. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname intersect #' @name intersect #' @export @@ -1754,7 +1752,7 @@ setMethod("intersect", #' @param x A Spark DataFrame #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the except operation. -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname except #' @name except #' @export @@ -1794,7 +1792,7 @@ setMethod("except", #' @param source A name for external data source #' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname write.df #' @name write.df #' @aliases saveDF @@ -1830,7 +1828,7 @@ setMethod("write.df", callJMethod(df@sdf, "save", source, jmode, options) }) -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname write.df #' @name saveDF #' @export @@ -1861,7 +1859,7 @@ setMethod("saveDF", #' @param source A name for external data source #' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname saveAsTable #' @name saveAsTable #' @export @@ -1902,7 +1900,7 @@ setMethod("saveAsTable", #' @param col A string of name #' @param ... Additional expressions #' @return A DataFrame -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname describe #' @name describe #' @aliases summary @@ -1925,7 +1923,7 @@ setMethod("describe", dataFrame(sdf) }) -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname describe #' @name describe setMethod("describe", @@ -1940,13 +1938,13 @@ setMethod("describe", #' #' @description Computes statistics for numeric columns of the DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname summary #' @name summary setMethod("summary", - signature(x = "DataFrame"), - function(x) { - describe(x) + signature(object = "DataFrame"), + function(object, ...) { + describe(object) }) @@ -1965,7 +1963,7 @@ setMethod("summary", #' @param cols Optional list of column names to consider. #' @return A DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname nafunctions #' @name dropna #' @aliases na.omit @@ -1995,7 +1993,7 @@ setMethod("dropna", dataFrame(sdf) }) -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname nafunctions #' @name na.omit #' @export @@ -2023,7 +2021,7 @@ setMethod("na.omit", #' column is simply ignored. #' @return A DataFrame #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname nafunctions #' @name fillna #' @export @@ -2087,7 +2085,7 @@ setMethod("fillna", #' @title Download data from a DataFrame into a data.frame #' @param x a DataFrame #' @return a data.frame -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname as.data.frame #' @examples \dontrun{ #' @@ -2108,7 +2106,7 @@ setMethod("as.data.frame", #' the DataFrame is searched by R when evaluating a variable, so columns in #' the DataFrame can be accessed by simply giving their names. #' -#' @family dataframe_funcs +#' @family DataFrame functions #' @rdname attach #' @title Attach DataFrame to R search path #' @param what (DataFrame) The DataFrame to attach @@ -2152,3 +2150,52 @@ setMethod("with", newEnv <- assignNewEnv(data) eval(substitute(expr), envir = newEnv, enclos = newEnv) }) + +#' Returns the column types of a DataFrame. +#' +#' @name coltypes +#' @title Get column types of a DataFrame +#' @family dataframe_funcs +#' @param x (DataFrame) +#' @return value (character) A character vector with the column types of the given DataFrame +#' @rdname coltypes +#' @examples \dontrun{ +#' irisDF <- createDataFrame(sqlContext, iris) +#' coltypes(irisDF) +#' } +setMethod("coltypes", + signature(x = "DataFrame"), + function(x) { + # Get the data types of the DataFrame by invoking dtypes() function + types <- sapply(dtypes(x), function(x) {x[[2]]}) + + # Map Spark data types into R's data types using DATA_TYPES environment + rTypes <- sapply(types, USE.NAMES=F, FUN=function(x) { + + # Check for primitive types + type <- PRIMITIVE_TYPES[[x]] + + if (is.null(type)) { + # Check for complex types + for (t in names(COMPLEX_TYPES)) { + if (substring(x, 1, nchar(t)) == t) { + type <- COMPLEX_TYPES[[t]] + break + } + } + + if (is.null(type)) { + stop(paste("Unsupported data type: ", x)) + } + } + type + }) + + # Find which types don't have mapping to R + naIndices <- which(is.na(rTypes)) + + # Assign the original scala data types to the unmatched ones + rTypes[naIndices] <- types[naIndices] + + rTypes + }) \ No newline at end of file diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 051e441d4e063..47945c2825da9 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -19,16 +19,15 @@ setOldClass("jobj") -# @title S4 class that represents an RDD -# @description RDD can be created using functions like -# \code{parallelize}, \code{textFile} etc. -# @rdname RDD -# @seealso parallelize, textFile -# -# @slot env An R environment that stores bookkeeping states of the RDD -# @slot jrdd Java object reference to the backing JavaRDD -# to an RDD -# @export +#' @title S4 class that represents an RDD +#' @description RDD can be created using functions like +#' \code{parallelize}, \code{textFile} etc. +#' @rdname RDD +#' @seealso parallelize, textFile +#' @slot env An R environment that stores bookkeeping states of the RDD +#' @slot jrdd Java object reference to the backing JavaRDD +#' to an RDD +#' @noRd setClass("RDD", slots = list(env = "environment", jrdd = "jobj")) @@ -111,14 +110,13 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) .Object }) -# @rdname RDD -# @export -# -# @param jrdd Java object reference to the backing JavaRDD -# @param serializedMode Use "byte" if the RDD stores data serialized in R, "string" if the RDD -# stores strings, and "row" if the RDD stores the rows of a DataFrame -# @param isCached TRUE if the RDD is cached -# @param isCheckpointed TRUE if the RDD has been checkpointed +#' @rdname RDD +#' @noRd +#' @param jrdd Java object reference to the backing JavaRDD +#' @param serializedMode Use "byte" if the RDD stores data serialized in R, "string" if the RDD +#' stores strings, and "row" if the RDD stores the rows of a DataFrame +#' @param isCached TRUE if the RDD is cached +#' @param isCheckpointed TRUE if the RDD has been checkpointed RDD <- function(jrdd, serializedMode = "byte", isCached = FALSE, isCheckpointed = FALSE) { new("RDD", jrdd, serializedMode, isCached, isCheckpointed) @@ -201,19 +199,20 @@ setValidity("RDD", ############ Actions and Transformations ############ -# Persist an RDD -# -# Persist this RDD with the default storage level (MEMORY_ONLY). -# -# @param x The RDD to cache -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2L) -# cache(rdd) -#} -# @rdname cache-methods -# @aliases cache,RDD-method +#' Persist an RDD +#' +#' Persist this RDD with the default storage level (MEMORY_ONLY). +#' +#' @param x The RDD to cache +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) +#'} +#' @rdname cache-methods +#' @aliases cache,RDD-method +#' @noRd setMethod("cache", signature(x = "RDD"), function(x) { @@ -222,22 +221,23 @@ setMethod("cache", x }) -# Persist an RDD -# -# Persist this RDD with the specified storage level. For details of the -# supported storage levels, refer to -# http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. -# -# @param x The RDD to persist -# @param newLevel The new storage level to be assigned -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2L) -# persist(rdd, "MEMORY_AND_DISK") -#} -# @rdname persist -# @aliases persist,RDD-method +#' Persist an RDD +#' +#' Persist this RDD with the specified storage level. For details of the +#' supported storage levels, refer to +#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#' +#' @param x The RDD to persist +#' @param newLevel The new storage level to be assigned +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' persist(rdd, "MEMORY_AND_DISK") +#'} +#' @rdname persist +#' @aliases persist,RDD-method +#' @noRd setMethod("persist", signature(x = "RDD", newLevel = "character"), function(x, newLevel = "MEMORY_ONLY") { @@ -246,21 +246,22 @@ setMethod("persist", x }) -# Unpersist an RDD -# -# Mark the RDD as non-persistent, and remove all blocks for it from memory and -# disk. -# -# @param x The RDD to unpersist -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2L) -# cache(rdd) # rdd@@env$isCached == TRUE -# unpersist(rdd) # rdd@@env$isCached == FALSE -#} -# @rdname unpersist-methods -# @aliases unpersist,RDD-method +#' Unpersist an RDD +#' +#' Mark the RDD as non-persistent, and remove all blocks for it from memory and +#' disk. +#' +#' @param x The RDD to unpersist +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' cache(rdd) # rdd@@env$isCached == TRUE +#' unpersist(rdd) # rdd@@env$isCached == FALSE +#'} +#' @rdname unpersist-methods +#' @aliases unpersist,RDD-method +#' @noRd setMethod("unpersist", signature(x = "RDD"), function(x) { @@ -269,24 +270,25 @@ setMethod("unpersist", x }) -# Checkpoint an RDD -# -# Mark this RDD for checkpointing. It will be saved to a file inside the -# checkpoint directory set with setCheckpointDir() and all references to its -# parent RDDs will be removed. This function must be called before any job has -# been executed on this RDD. It is strongly recommended that this RDD is -# persisted in memory, otherwise saving it on a file will require recomputation. -# -# @param x The RDD to checkpoint -# @examples -#\dontrun{ -# sc <- sparkR.init() -# setCheckpointDir(sc, "checkpoint") -# rdd <- parallelize(sc, 1:10, 2L) -# checkpoint(rdd) -#} -# @rdname checkpoint-methods -# @aliases checkpoint,RDD-method +#' Checkpoint an RDD +#' +#' Mark this RDD for checkpointing. It will be saved to a file inside the +#' checkpoint directory set with setCheckpointDir() and all references to its +#' parent RDDs will be removed. This function must be called before any job has +#' been executed on this RDD. It is strongly recommended that this RDD is +#' persisted in memory, otherwise saving it on a file will require recomputation. +#' +#' @param x The RDD to checkpoint +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "checkpoint") +#' rdd <- parallelize(sc, 1:10, 2L) +#' checkpoint(rdd) +#'} +#' @rdname checkpoint-methods +#' @aliases checkpoint,RDD-method +#' @noRd setMethod("checkpoint", signature(x = "RDD"), function(x) { @@ -296,18 +298,19 @@ setMethod("checkpoint", x }) -# Gets the number of partitions of an RDD -# -# @param x A RDD. -# @return the number of partitions of rdd as an integer. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2L) -# numPartitions(rdd) # 2L -#} -# @rdname numPartitions -# @aliases numPartitions,RDD-method +#' Gets the number of partitions of an RDD +#' +#' @param x A RDD. +#' @return the number of partitions of rdd as an integer. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' numPartitions(rdd) # 2L +#'} +#' @rdname numPartitions +#' @aliases numPartitions,RDD-method +#' @noRd setMethod("numPartitions", signature(x = "RDD"), function(x) { @@ -316,24 +319,25 @@ setMethod("numPartitions", callJMethod(partitions, "size") }) -# Collect elements of an RDD -# -# @description -# \code{collect} returns a list that contains all of the elements in this RDD. -# -# @param x The RDD to collect -# @param ... Other optional arguments to collect -# @param flatten FALSE if the list should not flattened -# @return a list containing elements in the RDD -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2L) -# collect(rdd) # list from 1 to 10 -# collectPartition(rdd, 0L) # list from 1 to 5 -#} -# @rdname collect-methods -# @aliases collect,RDD-method +#' Collect elements of an RDD +#' +#' @description +#' \code{collect} returns a list that contains all of the elements in this RDD. +#' +#' @param x The RDD to collect +#' @param ... Other optional arguments to collect +#' @param flatten FALSE if the list should not flattened +#' @return a list containing elements in the RDD +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2L) +#' collect(rdd) # list from 1 to 10 +#' collectPartition(rdd, 0L) # list from 1 to 5 +#'} +#' @rdname collect-methods +#' @aliases collect,RDD-method +#' @noRd setMethod("collect", signature(x = "RDD"), function(x, flatten = TRUE) { @@ -344,12 +348,13 @@ setMethod("collect", }) -# @description -# \code{collectPartition} returns a list that contains all of the elements -# in the specified partition of the RDD. -# @param partitionId the partition to collect (starts from 0) -# @rdname collect-methods -# @aliases collectPartition,integer,RDD-method +#' @description +#' \code{collectPartition} returns a list that contains all of the elements +#' in the specified partition of the RDD. +#' @param partitionId the partition to collect (starts from 0) +#' @rdname collect-methods +#' @aliases collectPartition,integer,RDD-method +#' @noRd setMethod("collectPartition", signature(x = "RDD", partitionId = "integer"), function(x, partitionId) { @@ -362,17 +367,18 @@ setMethod("collectPartition", serializedMode = getSerializedMode(x)) }) -# @description -# \code{collectAsMap} returns a named list as a map that contains all of the elements -# in a key-value pair RDD. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L) -# collectAsMap(rdd) # list(`1` = 2, `3` = 4) -#} -# @rdname collect-methods -# @aliases collectAsMap,RDD-method +#' @description +#' \code{collectAsMap} returns a named list as a map that contains all of the elements +#' in a key-value pair RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L) +#' collectAsMap(rdd) # list(`1` = 2, `3` = 4) +#'} +#' @rdname collect-methods +#' @aliases collectAsMap,RDD-method +#' @noRd setMethod("collectAsMap", signature(x = "RDD"), function(x) { @@ -382,19 +388,20 @@ setMethod("collectAsMap", as.list(map) }) -# Return the number of elements in the RDD. -# -# @param x The RDD to count -# @return number of elements in the RDD. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# count(rdd) # 10 -# length(rdd) # Same as count -#} -# @rdname count -# @aliases count,RDD-method +#' Return the number of elements in the RDD. +#' +#' @param x The RDD to count +#' @return number of elements in the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' count(rdd) # 10 +#' length(rdd) # Same as count +#'} +#' @rdname count +#' @aliases count,RDD-method +#' @noRd setMethod("count", signature(x = "RDD"), function(x) { @@ -406,31 +413,32 @@ setMethod("count", sum(as.integer(vals)) }) -# Return the number of elements in the RDD -# @export -# @rdname count +#' Return the number of elements in the RDD +#' @rdname count +#' @noRd setMethod("length", signature(x = "RDD"), function(x) { count(x) }) -# Return the count of each unique value in this RDD as a list of -# (value, count) pairs. -# -# Same as countByValue in Spark. -# -# @param x The RDD to count -# @return list of (value, count) pairs, where count is number of each unique -# value in rdd. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, c(1,2,3,2,1)) -# countByValue(rdd) # (1,2L), (2,2L), (3,1L) -#} -# @rdname countByValue -# @aliases countByValue,RDD-method +#' Return the count of each unique value in this RDD as a list of +#' (value, count) pairs. +#' +#' Same as countByValue in Spark. +#' +#' @param x The RDD to count +#' @return list of (value, count) pairs, where count is number of each unique +#' value in rdd. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,3,2,1)) +#' countByValue(rdd) # (1,2L), (2,2L), (3,1L) +#'} +#' @rdname countByValue +#' @aliases countByValue,RDD-method +#' @noRd setMethod("countByValue", signature(x = "RDD"), function(x) { @@ -438,23 +446,24 @@ setMethod("countByValue", collect(reduceByKey(ones, `+`, numPartitions(x))) }) -# Apply a function to all elements -# -# This function creates a new RDD by applying the given transformation to all -# elements of the given RDD -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on each element -# @return a new RDD created by the transformation. -# @rdname lapply -# @aliases lapply -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# multiplyByTwo <- lapply(rdd, function(x) { x * 2 }) -# collect(multiplyByTwo) # 2,4,6... -#} +#' Apply a function to all elements +#' +#' This function creates a new RDD by applying the given transformation to all +#' elements of the given RDD +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @rdname lapply +#' @noRd +#' @aliases lapply +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- lapply(rdd, function(x) { x * 2 }) +#' collect(multiplyByTwo) # 2,4,6... +#'} setMethod("lapply", signature(X = "RDD", FUN = "function"), function(X, FUN) { @@ -464,31 +473,33 @@ setMethod("lapply", lapplyPartitionsWithIndex(X, func) }) -# @rdname lapply -# @aliases map,RDD,function-method +#' @rdname lapply +#' @aliases map,RDD,function-method +#' @noRd setMethod("map", signature(X = "RDD", FUN = "function"), function(X, FUN) { lapply(X, FUN) }) -# Flatten results after apply a function to all elements -# -# This function return a new RDD by first applying a function to all -# elements of this RDD, and then flattening the results. -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on each element -# @return a new RDD created by the transformation. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) }) -# collect(multiplyByTwo) # 2,20,4,40,6,60... -#} -# @rdname flatMap -# @aliases flatMap,RDD,function-method +#' Flatten results after apply a function to all elements +#' +#' This function return a new RDD by first applying a function to all +#' elements of this RDD, and then flattening the results. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each element +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) }) +#' collect(multiplyByTwo) # 2,20,4,40,6,60... +#'} +#' @rdname flatMap +#' @aliases flatMap,RDD,function-method +#' @noRd setMethod("flatMap", signature(X = "RDD", FUN = "function"), function(X, FUN) { @@ -501,83 +512,88 @@ setMethod("flatMap", lapplyPartition(X, partitionFunc) }) -# Apply a function to each partition of an RDD -# -# Return a new RDD by applying a function to each partition of this RDD. -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on each partition. -# @return a new RDD created by the transformation. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) }) -# collect(partitionSum) # 15, 40 -#} -# @rdname lapplyPartition -# @aliases lapplyPartition,RDD,function-method +#' Apply a function to each partition of an RDD +#' +#' Return a new RDD by applying a function to each partition of this RDD. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) }) +#' collect(partitionSum) # 15, 40 +#'} +#' @rdname lapplyPartition +#' @aliases lapplyPartition,RDD,function-method +#' @noRd setMethod("lapplyPartition", signature(X = "RDD", FUN = "function"), function(X, FUN) { lapplyPartitionsWithIndex(X, function(s, part) { FUN(part) }) }) -# mapPartitions is the same as lapplyPartition. -# -# @rdname lapplyPartition -# @aliases mapPartitions,RDD,function-method +#' mapPartitions is the same as lapplyPartition. +#' +#' @rdname lapplyPartition +#' @aliases mapPartitions,RDD,function-method +#' @noRd setMethod("mapPartitions", signature(X = "RDD", FUN = "function"), function(X, FUN) { lapplyPartition(X, FUN) }) -# Return a new RDD by applying a function to each partition of this RDD, while -# tracking the index of the original partition. -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on each partition; takes the partition -# index and a list of elements in the particular partition. -# @return a new RDD created by the transformation. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 5L) -# prod <- lapplyPartitionsWithIndex(rdd, function(partIndex, part) { -# partIndex * Reduce("+", part) }) -# collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 -#} -# @rdname lapplyPartitionsWithIndex -# @aliases lapplyPartitionsWithIndex,RDD,function-method +#' Return a new RDD by applying a function to each partition of this RDD, while +#' tracking the index of the original partition. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on each partition; takes the partition +#' index and a list of elements in the particular partition. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 5L) +#' prod <- lapplyPartitionsWithIndex(rdd, function(partIndex, part) { +#' partIndex * Reduce("+", part) }) +#' collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 +#'} +#' @rdname lapplyPartitionsWithIndex +#' @aliases lapplyPartitionsWithIndex,RDD,function-method +#' @noRd setMethod("lapplyPartitionsWithIndex", signature(X = "RDD", FUN = "function"), function(X, FUN) { PipelinedRDD(X, FUN) }) -# @rdname lapplyPartitionsWithIndex -# @aliases mapPartitionsWithIndex,RDD,function-method +#' @rdname lapplyPartitionsWithIndex +#' @aliases mapPartitionsWithIndex,RDD,function-method +#' @noRd setMethod("mapPartitionsWithIndex", signature(X = "RDD", FUN = "function"), function(X, FUN) { lapplyPartitionsWithIndex(X, FUN) }) -# This function returns a new RDD containing only the elements that satisfy -# a predicate (i.e. returning TRUE in a given logical function). -# The same as `filter()' in Spark. -# -# @param x The RDD to be filtered. -# @param f A unary predicate function. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) -#} -# @rdname filterRDD -# @aliases filterRDD,RDD,function-method +#' This function returns a new RDD containing only the elements that satisfy +#' a predicate (i.e. returning TRUE in a given logical function). +#' The same as `filter()' in Spark. +#' +#' @param x The RDD to be filtered. +#' @param f A unary predicate function. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) +#'} +#' @rdname filterRDD +#' @aliases filterRDD,RDD,function-method +#' @noRd setMethod("filterRDD", signature(x = "RDD", f = "function"), function(x, f) { @@ -587,30 +603,32 @@ setMethod("filterRDD", lapplyPartition(x, filter.func) }) -# @rdname filterRDD -# @aliases Filter +#' @rdname filterRDD +#' @aliases Filter +#' @noRd setMethod("Filter", signature(f = "function", x = "RDD"), function(f, x) { filterRDD(x, f) }) -# Reduce across elements of an RDD. -# -# This function reduces the elements of this RDD using the -# specified commutative and associative binary operator. -# -# @param x The RDD to reduce -# @param func Commutative and associative function to apply on elements -# of the RDD. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# reduce(rdd, "+") # 55 -#} -# @rdname reduce -# @aliases reduce,RDD,ANY-method +#' Reduce across elements of an RDD. +#' +#' This function reduces the elements of this RDD using the +#' specified commutative and associative binary operator. +#' +#' @param x The RDD to reduce +#' @param func Commutative and associative function to apply on elements +#' of the RDD. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' reduce(rdd, "+") # 55 +#'} +#' @rdname reduce +#' @aliases reduce,RDD,ANY-method +#' @noRd setMethod("reduce", signature(x = "RDD", func = "ANY"), function(x, func) { @@ -624,70 +642,74 @@ setMethod("reduce", Reduce(func, partitionList) }) -# Get the maximum element of an RDD. -# -# @param x The RDD to get the maximum element from -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# maximum(rdd) # 10 -#} -# @rdname maximum -# @aliases maximum,RDD +#' Get the maximum element of an RDD. +#' +#' @param x The RDD to get the maximum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' maximum(rdd) # 10 +#'} +#' @rdname maximum +#' @aliases maximum,RDD +#' @noRd setMethod("maximum", signature(x = "RDD"), function(x) { reduce(x, max) }) -# Get the minimum element of an RDD. -# -# @param x The RDD to get the minimum element from -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# minimum(rdd) # 1 -#} -# @rdname minimum -# @aliases minimum,RDD +#' Get the minimum element of an RDD. +#' +#' @param x The RDD to get the minimum element from +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' minimum(rdd) # 1 +#'} +#' @rdname minimum +#' @aliases minimum,RDD +#' @noRd setMethod("minimum", signature(x = "RDD"), function(x) { reduce(x, min) }) -# Add up the elements in an RDD. -# -# @param x The RDD to add up the elements in -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# sumRDD(rdd) # 55 -#} -# @rdname sumRDD -# @aliases sumRDD,RDD +#' Add up the elements in an RDD. +#' +#' @param x The RDD to add up the elements in +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' sumRDD(rdd) # 55 +#'} +#' @rdname sumRDD +#' @aliases sumRDD,RDD +#' @noRd setMethod("sumRDD", signature(x = "RDD"), function(x) { reduce(x, "+") }) -# Applies a function to all elements in an RDD, and force evaluation. -# -# @param x The RDD to apply the function -# @param func The function to be applied. -# @return invisible NULL. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# foreach(rdd, function(x) { save(x, file=...) }) -#} -# @rdname foreach -# @aliases foreach,RDD,function-method +#' Applies a function to all elements in an RDD, and force evaluation. +#' +#' @param x The RDD to apply the function +#' @param func The function to be applied. +#' @return invisible NULL. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreach(rdd, function(x) { save(x, file=...) }) +#'} +#' @rdname foreach +#' @aliases foreach,RDD,function-method +#' @noRd setMethod("foreach", signature(x = "RDD", func = "function"), function(x, func) { @@ -698,37 +720,39 @@ setMethod("foreach", invisible(collect(mapPartitions(x, partition.func))) }) -# Applies a function to each partition in an RDD, and force evaluation. -# -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# foreachPartition(rdd, function(part) { save(part, file=...); NULL }) -#} -# @rdname foreach -# @aliases foreachPartition,RDD,function-method +#' Applies a function to each partition in an RDD, and force evaluation. +#' +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' foreachPartition(rdd, function(part) { save(part, file=...); NULL }) +#'} +#' @rdname foreach +#' @aliases foreachPartition,RDD,function-method +#' @noRd setMethod("foreachPartition", signature(x = "RDD", func = "function"), function(x, func) { invisible(collect(mapPartitions(x, func))) }) -# Take elements from an RDD. -# -# This function takes the first NUM elements in the RDD and -# returns them in a list. -# -# @param x The RDD to take elements from -# @param num Number of elements to take -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# take(rdd, 2L) # list(1, 2) -#} -# @rdname take -# @aliases take,RDD,numeric-method +#' Take elements from an RDD. +#' +#' This function takes the first NUM elements in the RDD and +#' returns them in a list. +#' +#' @param x The RDD to take elements from +#' @param num Number of elements to take +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' take(rdd, 2L) # list(1, 2) +#'} +#' @rdname take +#' @aliases take,RDD,numeric-method +#' @noRd setMethod("take", signature(x = "RDD", num = "numeric"), function(x, num) { @@ -763,39 +787,40 @@ setMethod("take", }) -# First -# -# Return the first element of an RDD -# -# @rdname first -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# first(rdd) -# } +#' First +#' +#' Return the first element of an RDD +#' +#' @rdname first +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' first(rdd) +#' } +#' @noRd setMethod("first", signature(x = "RDD"), function(x) { take(x, 1)[[1]] }) -# Removes the duplicates from RDD. -# -# This function returns a new RDD containing the distinct elements in the -# given RDD. The same as `distinct()' in Spark. -# -# @param x The RDD to remove duplicates from. -# @param numPartitions Number of partitions to create. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, c(1,2,2,3,3,3)) -# sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) -#} -# @rdname distinct -# @aliases distinct,RDD-method +#' Removes the duplicates from RDD. +#' +#' This function returns a new RDD containing the distinct elements in the +#' given RDD. The same as `distinct()' in Spark. +#' +#' @param x The RDD to remove duplicates from. +#' @param numPartitions Number of partitions to create. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, c(1,2,2,3,3,3)) +#' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) +#'} +#' @rdname distinct +#' @aliases distinct,RDD-method +#' @noRd setMethod("distinct", signature(x = "RDD"), function(x, numPartitions = SparkR:::numPartitions(x)) { @@ -807,24 +832,25 @@ setMethod("distinct", resRDD }) -# Return an RDD that is a sampled subset of the given RDD. -# -# The same as `sample()' in Spark. (We rename it due to signature -# inconsistencies with the `sample()' function in R's base package.) -# -# @param x The RDD to sample elements from -# @param withReplacement Sampling with replacement or not -# @param fraction The (rough) sample target fraction -# @param seed Randomness seed value -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements -# collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates -#} -# @rdname sampleRDD -# @aliases sampleRDD,RDD +#' Return an RDD that is a sampled subset of the given RDD. +#' +#' The same as `sample()' in Spark. (We rename it due to signature +#' inconsistencies with the `sample()' function in R's base package.) +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements +#' collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates +#'} +#' @rdname sampleRDD +#' @aliases sampleRDD,RDD +#' @noRd setMethod("sampleRDD", signature(x = "RDD", withReplacement = "logical", fraction = "numeric", seed = "integer"), @@ -868,23 +894,24 @@ setMethod("sampleRDD", lapplyPartitionsWithIndex(x, samplingFunc) }) -# Return a list of the elements that are a sampled subset of the given RDD. -# -# @param x The RDD to sample elements from -# @param withReplacement Sampling with replacement or not -# @param num Number of elements to return -# @param seed Randomness seed value -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:100) -# # exactly 5 elements sampled, which may not be distinct -# takeSample(rdd, TRUE, 5L, 1618L) -# # exactly 5 distinct elements sampled -# takeSample(rdd, FALSE, 5L, 16181618L) -#} -# @rdname takeSample -# @aliases takeSample,RDD +#' Return a list of the elements that are a sampled subset of the given RDD. +#' +#' @param x The RDD to sample elements from +#' @param withReplacement Sampling with replacement or not +#' @param num Number of elements to return +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:100) +#' # exactly 5 elements sampled, which may not be distinct +#' takeSample(rdd, TRUE, 5L, 1618L) +#' # exactly 5 distinct elements sampled +#' takeSample(rdd, FALSE, 5L, 16181618L) +#'} +#' @rdname takeSample +#' @aliases takeSample,RDD +#' @noRd setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", num = "integer", seed = "integer"), function(x, withReplacement, num, seed) { @@ -931,18 +958,19 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", base::sample(samples)[1:total] }) -# Creates tuples of the elements in this RDD by applying a function. -# -# @param x The RDD. -# @param func The function to be applied. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1, 2, 3)) -# collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) -#} -# @rdname keyBy -# @aliases keyBy,RDD +#' Creates tuples of the elements in this RDD by applying a function. +#' +#' @param x The RDD. +#' @param func The function to be applied. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3)) +#' collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) +#'} +#' @rdname keyBy +#' @aliases keyBy,RDD +#' @noRd setMethod("keyBy", signature(x = "RDD", func = "function"), function(x, func) { @@ -952,44 +980,46 @@ setMethod("keyBy", lapply(x, apply.func) }) -# Return a new RDD that has exactly numPartitions partitions. -# Can increase or decrease the level of parallelism in this RDD. Internally, -# this uses a shuffle to redistribute data. -# If you are decreasing the number of partitions in this RDD, consider using -# coalesce, which can avoid performing a shuffle. -# -# @param x The RDD. -# @param numPartitions Number of partitions to create. -# @seealso coalesce -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) -# numPartitions(rdd) # 4 -# numPartitions(repartition(rdd, 2L)) # 2 -#} -# @rdname repartition -# @aliases repartition,RDD +#' Return a new RDD that has exactly numPartitions partitions. +#' Can increase or decrease the level of parallelism in this RDD. Internally, +#' this uses a shuffle to redistribute data. +#' If you are decreasing the number of partitions in this RDD, consider using +#' coalesce, which can avoid performing a shuffle. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso coalesce +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) +#' numPartitions(rdd) # 4 +#' numPartitions(repartition(rdd, 2L)) # 2 +#'} +#' @rdname repartition +#' @aliases repartition,RDD +#' @noRd setMethod("repartition", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions) { coalesce(x, numPartitions, TRUE) }) -# Return a new RDD that is reduced into numPartitions partitions. -# -# @param x The RDD. -# @param numPartitions Number of partitions to create. -# @seealso repartition -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L) -# numPartitions(rdd) # 3 -# numPartitions(coalesce(rdd, 1L)) # 1 -#} -# @rdname coalesce -# @aliases coalesce,RDD +#' Return a new RDD that is reduced into numPartitions partitions. +#' +#' @param x The RDD. +#' @param numPartitions Number of partitions to create. +#' @seealso repartition +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L) +#' numPartitions(rdd) # 3 +#' numPartitions(coalesce(rdd, 1L)) # 1 +#'} +#' @rdname coalesce +#' @aliases coalesce,RDD +#' @noRd setMethod("coalesce", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions, shuffle = FALSE) { @@ -1013,19 +1043,20 @@ setMethod("coalesce", } }) -# Save this RDD as a SequenceFile of serialized objects. -# -# @param x The RDD to save -# @param path The directory where the file is saved -# @seealso objectFile -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:3) -# saveAsObjectFile(rdd, "/tmp/sparkR-tmp") -#} -# @rdname saveAsObjectFile -# @aliases saveAsObjectFile,RDD +#' Save this RDD as a SequenceFile of serialized objects. +#' +#' @param x The RDD to save +#' @param path The directory where the file is saved +#' @seealso objectFile +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsObjectFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsObjectFile +#' @aliases saveAsObjectFile,RDD +#' @noRd setMethod("saveAsObjectFile", signature(x = "RDD", path = "character"), function(x, path) { @@ -1038,18 +1069,19 @@ setMethod("saveAsObjectFile", invisible(callJMethod(getJRDD(x), "saveAsObjectFile", path)) }) -# Save this RDD as a text file, using string representations of elements. -# -# @param x The RDD to save -# @param path The directory where the partitions of the text file are saved -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:3) -# saveAsTextFile(rdd, "/tmp/sparkR-tmp") -#} -# @rdname saveAsTextFile -# @aliases saveAsTextFile,RDD +#' Save this RDD as a text file, using string representations of elements. +#' +#' @param x The RDD to save +#' @param path The directory where the partitions of the text file are saved +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' saveAsTextFile(rdd, "/tmp/sparkR-tmp") +#'} +#' @rdname saveAsTextFile +#' @aliases saveAsTextFile,RDD +#' @noRd setMethod("saveAsTextFile", signature(x = "RDD", path = "character"), function(x, path) { @@ -1062,21 +1094,22 @@ setMethod("saveAsTextFile", callJMethod(getJRDD(stringRdd, serializedMode = "string"), "saveAsTextFile", path)) }) -# Sort an RDD by the given key function. -# -# @param x An RDD to be sorted. -# @param func A function used to compute the sort key for each element. -# @param ascending A flag to indicate whether the sorting is ascending or descending. -# @param numPartitions Number of partitions to create. -# @return An RDD where all elements are sorted. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(3, 2, 1)) -# collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) -#} -# @rdname sortBy -# @aliases sortBy,RDD,RDD-method +#' Sort an RDD by the given key function. +#' +#' @param x An RDD to be sorted. +#' @param func A function used to compute the sort key for each element. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all elements are sorted. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(3, 2, 1)) +#' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) +#'} +#' @rdname sortBy +#' @aliases sortBy,RDD,RDD-method +#' @noRd setMethod("sortBy", signature(x = "RDD", func = "function"), function(x, func, ascending = TRUE, numPartitions = SparkR:::numPartitions(x)) { @@ -1138,97 +1171,95 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { resList } -# Returns the first N elements from an RDD in ascending order. -# -# @param x An RDD. -# @param num Number of elements to return. -# @return The first N elements from the RDD in ascending order. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) -# takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6) -#} -# @rdname takeOrdered -# @aliases takeOrdered,RDD,RDD-method +#' Returns the first N elements from an RDD in ascending order. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The first N elements from the RDD in ascending order. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6) +#'} +#' @rdname takeOrdered +#' @aliases takeOrdered,RDD,RDD-method +#' @noRd setMethod("takeOrdered", signature(x = "RDD", num = "integer"), function(x, num) { takeOrderedElem(x, num) }) -# Returns the top N elements from an RDD. -# -# @param x An RDD. -# @param num Number of elements to return. -# @return The top N elements from the RDD. -# @rdname top -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) -# top(rdd, 6L) # list(10, 9, 7, 6, 5, 4) -#} -# @rdname top -# @aliases top,RDD,RDD-method +#' Returns the top N elements from an RDD. +#' +#' @param x An RDD. +#' @param num Number of elements to return. +#' @return The top N elements from the RDD. +#' @rdname top +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) +#' top(rdd, 6L) # list(10, 9, 7, 6, 5, 4) +#'} +#' @aliases top,RDD,RDD-method +#' @noRd setMethod("top", signature(x = "RDD", num = "integer"), function(x, num) { takeOrderedElem(x, num, FALSE) }) -# Fold an RDD using a given associative function and a neutral "zero value". -# -# Aggregate the elements of each partition, and then the results for all the -# partitions, using a given associative function and a neutral "zero value". -# -# @param x An RDD. -# @param zeroValue A neutral "zero value". -# @param op An associative function for the folding operation. -# @return The folding result. -# @rdname fold -# @seealso reduce -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1, 2, 3, 4, 5)) -# fold(rdd, 0, "+") # 15 -#} -# @rdname fold -# @aliases fold,RDD,RDD-method +#' Fold an RDD using a given associative function and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using a given associative function and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param op An associative function for the folding operation. +#' @return The folding result. +#' @rdname fold +#' @seealso reduce +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5)) +#' fold(rdd, 0, "+") # 15 +#'} +#' @aliases fold,RDD,RDD-method +#' @noRd setMethod("fold", signature(x = "RDD", zeroValue = "ANY", op = "ANY"), function(x, zeroValue, op) { aggregateRDD(x, zeroValue, op, op) }) -# Aggregate an RDD using the given combine functions and a neutral "zero value". -# -# Aggregate the elements of each partition, and then the results for all the -# partitions, using given combine functions and a neutral "zero value". -# -# @param x An RDD. -# @param zeroValue A neutral "zero value". -# @param seqOp A function to aggregate the RDD elements. It may return a different -# result type from the type of the RDD elements. -# @param combOp A function to aggregate results of seqOp. -# @return The aggregation result. -# @rdname aggregateRDD -# @seealso reduce -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1, 2, 3, 4)) -# zeroValue <- list(0, 0) -# seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } -# combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } -# aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4) -#} -# @rdname aggregateRDD -# @aliases aggregateRDD,RDD,RDD-method +#' Aggregate an RDD using the given combine functions and a neutral "zero value". +#' +#' Aggregate the elements of each partition, and then the results for all the +#' partitions, using given combine functions and a neutral "zero value". +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the RDD elements. It may return a different +#' result type from the type of the RDD elements. +#' @param combOp A function to aggregate results of seqOp. +#' @return The aggregation result. +#' @rdname aggregateRDD +#' @seealso reduce +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1, 2, 3, 4)) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4) +#'} +#' @aliases aggregateRDD,RDD,RDD-method +#' @noRd setMethod("aggregateRDD", signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY"), function(x, zeroValue, seqOp, combOp) { @@ -1241,25 +1272,24 @@ setMethod("aggregateRDD", Reduce(combOp, partitionList, zeroValue) }) -# Pipes elements to a forked external process. -# -# The same as 'pipe()' in Spark. -# -# @param x The RDD whose elements are piped to the forked external process. -# @param command The command to fork an external process. -# @param env A named list to set environment variables of the external process. -# @return A new RDD created by piping all elements to a forked external process. -# @rdname pipeRDD -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# collect(pipeRDD(rdd, "more") -# Output: c("1", "2", ..., "10") -#} -# @rdname pipeRDD -# @aliases pipeRDD,RDD,character-method +#' Pipes elements to a forked external process. +#' +#' The same as 'pipe()' in Spark. +#' +#' @param x The RDD whose elements are piped to the forked external process. +#' @param command The command to fork an external process. +#' @param env A named list to set environment variables of the external process. +#' @return A new RDD created by piping all elements to a forked external process. +#' @rdname pipeRDD +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' collect(pipeRDD(rdd, "more") +#' Output: c("1", "2", ..., "10") +#'} +#' @aliases pipeRDD,RDD,character-method +#' @noRd setMethod("pipeRDD", signature(x = "RDD", command = "character"), function(x, command, env = list()) { @@ -1274,42 +1304,40 @@ setMethod("pipeRDD", lapplyPartition(x, func) }) -# TODO: Consider caching the name in the RDD's environment -# Return an RDD's name. -# -# @param x The RDD whose name is returned. -# @rdname name -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1,2,3)) -# name(rdd) # NULL (if not set before) -#} -# @rdname name -# @aliases name,RDD +#' TODO: Consider caching the name in the RDD's environment +#' Return an RDD's name. +#' +#' @param x The RDD whose name is returned. +#' @rdname name +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' name(rdd) # NULL (if not set before) +#'} +#' @aliases name,RDD +#' @noRd setMethod("name", signature(x = "RDD"), function(x) { callJMethod(getJRDD(x), "name") }) -# Set an RDD's name. -# -# @param x The RDD whose name is to be set. -# @param name The RDD name to be set. -# @return a new RDD renamed. -# @rdname setName -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(1,2,3)) -# setName(rdd, "myRDD") -# name(rdd) # "myRDD" -#} -# @rdname setName -# @aliases setName,RDD +#' Set an RDD's name. +#' +#' @param x The RDD whose name is to be set. +#' @param name The RDD name to be set. +#' @return a new RDD renamed. +#' @rdname setName +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(1,2,3)) +#' setName(rdd, "myRDD") +#' name(rdd) # "myRDD" +#'} +#' @aliases setName,RDD +#' @noRd setMethod("setName", signature(x = "RDD", name = "character"), function(x, name) { @@ -1317,25 +1345,26 @@ setMethod("setName", x }) -# Zip an RDD with generated unique Long IDs. -# -# Items in the kth partition will get ids k, n+k, 2*n+k, ..., where -# n is the number of partitions. So there may exist gaps, but this -# method won't trigger a spark job, which is different from -# zipWithIndex. -# -# @param x An RDD to be zipped. -# @return An RDD with zipped items. -# @seealso zipWithIndex -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) -# collect(zipWithUniqueId(rdd)) -# # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) -#} -# @rdname zipWithUniqueId -# @aliases zipWithUniqueId,RDD +#' Zip an RDD with generated unique Long IDs. +#' +#' Items in the kth partition will get ids k, n+k, 2*n+k, ..., where +#' n is the number of partitions. So there may exist gaps, but this +#' method won't trigger a spark job, which is different from +#' zipWithIndex. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithIndex +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithUniqueId(rdd)) +#' # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) +#'} +#' @rdname zipWithUniqueId +#' @aliases zipWithUniqueId,RDD +#' @noRd setMethod("zipWithUniqueId", signature(x = "RDD"), function(x) { @@ -1354,28 +1383,29 @@ setMethod("zipWithUniqueId", lapplyPartitionsWithIndex(x, partitionFunc) }) -# Zip an RDD with its element indices. -# -# The ordering is first based on the partition index and then the -# ordering of items within each partition. So the first item in -# the first partition gets index 0, and the last item in the last -# partition receives the largest index. -# -# This method needs to trigger a Spark job when this RDD contains -# more than one partition. -# -# @param x An RDD to be zipped. -# @return An RDD with zipped items. -# @seealso zipWithUniqueId -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) -# collect(zipWithIndex(rdd)) -# # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) -#} -# @rdname zipWithIndex -# @aliases zipWithIndex,RDD +#' Zip an RDD with its element indices. +#' +#' The ordering is first based on the partition index and then the +#' ordering of items within each partition. So the first item in +#' the first partition gets index 0, and the last item in the last +#' partition receives the largest index. +#' +#' This method needs to trigger a Spark job when this RDD contains +#' more than one partition. +#' +#' @param x An RDD to be zipped. +#' @return An RDD with zipped items. +#' @seealso zipWithUniqueId +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) +#' collect(zipWithIndex(rdd)) +#' # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) +#'} +#' @rdname zipWithIndex +#' @aliases zipWithIndex,RDD +#' @noRd setMethod("zipWithIndex", signature(x = "RDD"), function(x) { @@ -1407,20 +1437,21 @@ setMethod("zipWithIndex", lapplyPartitionsWithIndex(x, partitionFunc) }) -# Coalesce all elements within each partition of an RDD into a list. -# -# @param x An RDD. -# @return An RDD created by coalescing all elements within -# each partition into a list. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, as.list(1:4), 2L) -# collect(glom(rdd)) -# # list(list(1, 2), list(3, 4)) -#} -# @rdname glom -# @aliases glom,RDD +#' Coalesce all elements within each partition of an RDD into a list. +#' +#' @param x An RDD. +#' @return An RDD created by coalescing all elements within +#' each partition into a list. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, as.list(1:4), 2L) +#' collect(glom(rdd)) +#' # list(list(1, 2), list(3, 4)) +#'} +#' @rdname glom +#' @aliases glom,RDD +#' @noRd setMethod("glom", signature(x = "RDD"), function(x) { @@ -1433,21 +1464,22 @@ setMethod("glom", ############ Binary Functions ############# -# Return the union RDD of two RDDs. -# The same as union() in Spark. -# -# @param x An RDD. -# @param y An RDD. -# @return a new RDD created by performing the simple union (witout removing -# duplicates) of two input RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:3) -# unionRDD(rdd, rdd) # 1, 2, 3, 1, 2, 3 -#} -# @rdname unionRDD -# @aliases unionRDD,RDD,RDD-method +#' Return the union RDD of two RDDs. +#' The same as union() in Spark. +#' +#' @param x An RDD. +#' @param y An RDD. +#' @return a new RDD created by performing the simple union (witout removing +#' duplicates) of two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3) +#' unionRDD(rdd, rdd) # 1, 2, 3, 1, 2, 3 +#'} +#' @rdname unionRDD +#' @aliases unionRDD,RDD,RDD-method +#' @noRd setMethod("unionRDD", signature(x = "RDD", y = "RDD"), function(x, y) { @@ -1464,27 +1496,28 @@ setMethod("unionRDD", union.rdd }) -# Zip an RDD with another RDD. -# -# Zips this RDD with another one, returning key-value pairs with the -# first element in each RDD second element in each RDD, etc. Assumes -# that the two RDDs have the same number of partitions and the same -# number of elements in each partition (e.g. one was made through -# a map on the other). -# -# @param x An RDD to be zipped. -# @param other Another RDD to be zipped. -# @return An RDD zipped from the two RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, 0:4) -# rdd2 <- parallelize(sc, 1000:1004) -# collect(zipRDD(rdd1, rdd2)) -# # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) -#} -# @rdname zipRDD -# @aliases zipRDD,RDD +#' Zip an RDD with another RDD. +#' +#' Zips this RDD with another one, returning key-value pairs with the +#' first element in each RDD second element in each RDD, etc. Assumes +#' that the two RDDs have the same number of partitions and the same +#' number of elements in each partition (e.g. one was made through +#' a map on the other). +#' +#' @param x An RDD to be zipped. +#' @param other Another RDD to be zipped. +#' @return An RDD zipped from the two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, 0:4) +#' rdd2 <- parallelize(sc, 1000:1004) +#' collect(zipRDD(rdd1, rdd2)) +#' # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) +#'} +#' @rdname zipRDD +#' @aliases zipRDD,RDD +#' @noRd setMethod("zipRDD", signature(x = "RDD", other = "RDD"), function(x, other) { @@ -1503,24 +1536,25 @@ setMethod("zipRDD", mergePartitions(rdd, TRUE) }) -# Cartesian product of this RDD and another one. -# -# Return the Cartesian product of this RDD and another one, -# that is, the RDD of all pairs of elements (a, b) where a -# is in this and b is in other. -# -# @param x An RDD. -# @param other An RDD. -# @return A new RDD which is the Cartesian product of these two RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:2) -# sortByKey(cartesian(rdd, rdd)) -# # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) -#} -# @rdname cartesian -# @aliases cartesian,RDD,RDD-method +#' Cartesian product of this RDD and another one. +#' +#' Return the Cartesian product of this RDD and another one, +#' that is, the RDD of all pairs of elements (a, b) where a +#' is in this and b is in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @return A new RDD which is the Cartesian product of these two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:2) +#' sortByKey(cartesian(rdd, rdd)) +#' # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) +#'} +#' @rdname cartesian +#' @aliases cartesian,RDD,RDD-method +#' @noRd setMethod("cartesian", signature(x = "RDD", other = "RDD"), function(x, other) { @@ -1533,24 +1567,25 @@ setMethod("cartesian", mergePartitions(rdd, FALSE) }) -# Subtract an RDD with another RDD. -# -# Return an RDD with the elements from this that are not in other. -# -# @param x An RDD. -# @param other An RDD. -# @param numPartitions Number of the partitions in the result RDD. -# @return An RDD with the elements from this that are not in other. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4)) -# rdd2 <- parallelize(sc, list(2, 4)) -# collect(subtract(rdd1, rdd2)) -# # list(1, 1, 3) -#} -# @rdname subtract -# @aliases subtract,RDD +#' Subtract an RDD with another RDD. +#' +#' Return an RDD with the elements from this that are not in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions Number of the partitions in the result RDD. +#' @return An RDD with the elements from this that are not in other. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4)) +#' rdd2 <- parallelize(sc, list(2, 4)) +#' collect(subtract(rdd1, rdd2)) +#' # list(1, 1, 3) +#'} +#' @rdname subtract +#' @aliases subtract,RDD +#' @noRd setMethod("subtract", signature(x = "RDD", other = "RDD"), function(x, other, numPartitions = SparkR:::numPartitions(x)) { @@ -1560,28 +1595,29 @@ setMethod("subtract", keys(subtractByKey(rdd1, rdd2, numPartitions)) }) -# Intersection of this RDD and another one. -# -# Return the intersection of this RDD and another one. -# The output will not contain any duplicate elements, -# even if the input RDDs did. Performs a hash partition -# across the cluster. -# Note that this method performs a shuffle internally. -# -# @param x An RDD. -# @param other An RDD. -# @param numPartitions The number of partitions in the result RDD. -# @return An RDD which is the intersection of these two RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) -# rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) -# collect(sortBy(intersection(rdd1, rdd2), function(x) { x })) -# # list(1, 2, 3) -#} -# @rdname intersection -# @aliases intersection,RDD +#' Intersection of this RDD and another one. +#' +#' Return the intersection of this RDD and another one. +#' The output will not contain any duplicate elements, +#' even if the input RDDs did. Performs a hash partition +#' across the cluster. +#' Note that this method performs a shuffle internally. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions The number of partitions in the result RDD. +#' @return An RDD which is the intersection of these two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) +#' rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) +#' collect(sortBy(intersection(rdd1, rdd2), function(x) { x })) +#' # list(1, 2, 3) +#'} +#' @rdname intersection +#' @aliases intersection,RDD +#' @noRd setMethod("intersection", signature(x = "RDD", other = "RDD"), function(x, other, numPartitions = SparkR:::numPartitions(x)) { @@ -1597,26 +1633,27 @@ setMethod("intersection", keys(filterRDD(cogroup(rdd1, rdd2, numPartitions = numPartitions), filterFunction)) }) -# Zips an RDD's partitions with one (or more) RDD(s). -# Same as zipPartitions in Spark. -# -# @param ... RDDs to be zipped. -# @param func A function to transform zipped partitions. -# @return A new RDD by applying a function to the zipped partitions. -# Assumes that all the RDDs have the *same number of partitions*, but -# does *not* require them to have the same number of elements in each partition. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 -# rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 -# rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 -# collect(zipPartitions(rdd1, rdd2, rdd3, -# func = function(x, y, z) { list(list(x, y, z))} )) -# # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) -#} -# @rdname zipRDD -# @aliases zipPartitions,RDD +#' Zips an RDD's partitions with one (or more) RDD(s). +#' Same as zipPartitions in Spark. +#' +#' @param ... RDDs to be zipped. +#' @param func A function to transform zipped partitions. +#' @return A new RDD by applying a function to the zipped partitions. +#' Assumes that all the RDDs have the *same number of partitions*, but +#' does *not* require them to have the same number of elements in each partition. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 +#' rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 +#' rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 +#' collect(zipPartitions(rdd1, rdd2, rdd3, +#' func = function(x, y, z) { list(list(x, y, z))} )) +#' # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) +#'} +#' @rdname zipRDD +#' @aliases zipPartitions,RDD +#' @noRd setMethod("zipPartitions", "RDD", function(..., func) { diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 1bf025cce4376..a62b25fde926d 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -17,27 +17,33 @@ # SQLcontext.R: SQLContext-driven functions + +# Map top level R type to SQL type +getInternalType <- function(x) { + # class of POSIXlt is c("POSIXlt" "POSIXt") + switch(class(x)[[1]], + integer = "integer", + character = "string", + logical = "boolean", + double = "double", + numeric = "double", + raw = "binary", + list = "array", + struct = "struct", + environment = "map", + Date = "date", + POSIXlt = "timestamp", + POSIXct = "timestamp", + stop(paste("Unsupported type for DataFrame:", class(x)))) +} + #' infer the SQL type infer_type <- function(x) { if (is.null(x)) { stop("can not infer type from NULL") } - # class of POSIXlt is c("POSIXlt" "POSIXt") - type <- switch(class(x)[[1]], - integer = "integer", - character = "string", - logical = "boolean", - double = "double", - numeric = "double", - raw = "binary", - list = "array", - struct = "struct", - environment = "map", - Date = "date", - POSIXlt = "timestamp", - POSIXct = "timestamp", - stop(paste("Unsupported type for DataFrame:", class(x)))) + type <- getInternalType(x) if (type == "map") { stopifnot(length(x) > 0) @@ -90,19 +96,25 @@ createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0 if (is.null(schema)) { schema <- names(data) } - n <- nrow(data) - m <- ncol(data) + # get rid of factor type - dropFactor <- function(x) { + cleanCols <- function(x) { if (is.factor(x)) { as.character(x) } else { x } } - data <- lapply(1:n, function(i) { - lapply(1:m, function(j) { dropFactor(data[i,j]) }) - }) + + # drop factors and wrap lists + data <- setNames(lapply(data, cleanCols), NULL) + + # check if all columns have supported type + lapply(data, getInternalType) + + # convert to rows + args <- list(FUN = list, SIMPLIFY = FALSE, USE.NAMES = FALSE) + data <- do.call(mapply, append(args, data)) } if (is.list(data)) { sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlContext) @@ -144,7 +156,6 @@ createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0 } stopifnot(class(schema) == "structType") - # schemaString <- tojson(schema) jrdd <- getJRDD(lapply(rdd, function(x) x), "row") srdd <- callJMethod(jrdd, "rdd") @@ -160,22 +171,21 @@ as.DataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { createDataFrame(sqlContext, data, schema, samplingRatio) } -# toDF -# -# Converts an RDD to a DataFrame by infer the types. -# -# @param x An RDD -# -# @rdname DataFrame -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# sqlContext <- sparkRSQL.init(sc) -# rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) -# df <- toDF(rdd) -# } - +#' toDF +#' +#' Converts an RDD to a DataFrame by infer the types. +#' +#' @param x An RDD +#' +#' @rdname DataFrame +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) +#' df <- toDF(rdd) +#'} setGeneric("toDF", function(x, ...) { standardGeneric("toDF") }) setMethod("toDF", signature(x = "RDD"), @@ -217,23 +227,23 @@ jsonFile <- function(sqlContext, path) { } -# JSON RDD -# -# Loads an RDD storing one JSON object per string as a DataFrame. -# -# @param sqlContext SQLContext to use -# @param rdd An RDD of JSON string -# @param schema A StructType object to use as schema -# @param samplingRatio The ratio of simpling used to infer the schema -# @return A DataFrame -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# sqlContext <- sparkRSQL.init(sc) -# rdd <- texFile(sc, "path/to/json") -# df <- jsonRDD(sqlContext, rdd) -# } +#' JSON RDD +#' +#' Loads an RDD storing one JSON object per string as a DataFrame. +#' +#' @param sqlContext SQLContext to use +#' @param rdd An RDD of JSON string +#' @param schema A StructType object to use as schema +#' @param samplingRatio The ratio of simpling used to infer the schema +#' @return A DataFrame +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' rdd <- texFile(sc, "path/to/json") +#' df <- jsonRDD(sqlContext, rdd) +#'} # TODO: support schema jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 720990e1c6087..471bec1eacf03 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -25,23 +25,23 @@ getMinPartitions <- function(sc, minPartitions) { as.integer(minPartitions) } -# Create an RDD from a text file. -# -# This function reads a text file from HDFS, a local file system (available on all -# nodes), or any Hadoop-supported file system URI, and creates an -# RDD of strings from it. -# -# @param sc SparkContext to use -# @param path Path of file to read. A vector of multiple paths is allowed. -# @param minPartitions Minimum number of partitions to be created. If NULL, the default -# value is chosen based on available parallelism. -# @return RDD where each item is of type \code{character} -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# lines <- textFile(sc, "myfile.txt") -#} +#' Create an RDD from a text file. +#' +#' This function reads a text file from HDFS, a local file system (available on all +#' nodes), or any Hadoop-supported file system URI, and creates an +#' RDD of strings from it. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minPartitions Minimum number of partitions to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD where each item is of type \code{character} +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' lines <- textFile(sc, "myfile.txt") +#'} textFile <- function(sc, path, minPartitions = NULL) { # Allow the user to have a more flexible definiton of the text file path path <- suppressWarnings(normalizePath(path)) @@ -53,23 +53,23 @@ textFile <- function(sc, path, minPartitions = NULL) { RDD(jrdd, "string") } -# Load an RDD saved as a SequenceFile containing serialized objects. -# -# The file to be loaded should be one that was previously generated by calling -# saveAsObjectFile() of the RDD class. -# -# @param sc SparkContext to use -# @param path Path of file to read. A vector of multiple paths is allowed. -# @param minPartitions Minimum number of partitions to be created. If NULL, the default -# value is chosen based on available parallelism. -# @return RDD containing serialized R objects. -# @seealso saveAsObjectFile -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- objectFile(sc, "myfile") -#} +#' Load an RDD saved as a SequenceFile containing serialized objects. +#' +#' The file to be loaded should be one that was previously generated by calling +#' saveAsObjectFile() of the RDD class. +#' +#' @param sc SparkContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param minPartitions Minimum number of partitions to be created. If NULL, the default +#' value is chosen based on available parallelism. +#' @return RDD containing serialized R objects. +#' @seealso saveAsObjectFile +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- objectFile(sc, "myfile") +#'} objectFile <- function(sc, path, minPartitions = NULL) { # Allow the user to have a more flexible definiton of the text file path path <- suppressWarnings(normalizePath(path)) @@ -81,24 +81,24 @@ objectFile <- function(sc, path, minPartitions = NULL) { RDD(jrdd, "byte") } -# Create an RDD from a homogeneous list or vector. -# -# This function creates an RDD from a local homogeneous list in R. The elements -# in the list are split into \code{numSlices} slices and distributed to nodes -# in the cluster. -# -# @param sc SparkContext to use -# @param coll collection to parallelize -# @param numSlices number of partitions to create in the RDD -# @return an RDD created from this collection -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10, 2) -# # The RDD should contain 10 elements -# length(rdd) -#} +#' Create an RDD from a homogeneous list or vector. +#' +#' This function creates an RDD from a local homogeneous list in R. The elements +#' in the list are split into \code{numSlices} slices and distributed to nodes +#' in the cluster. +#' +#' @param sc SparkContext to use +#' @param coll collection to parallelize +#' @param numSlices number of partitions to create in the RDD +#' @return an RDD created from this collection +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10, 2) +#' # The RDD should contain 10 elements +#' length(rdd) +#'} parallelize <- function(sc, coll, numSlices = 1) { # TODO: bound/safeguard numSlices # TODO: unit tests for if the split works for all primitives @@ -133,33 +133,32 @@ parallelize <- function(sc, coll, numSlices = 1) { RDD(jrdd, "byte") } -# Include this specified package on all workers -# -# This function can be used to include a package on all workers before the -# user's code is executed. This is useful in scenarios where other R package -# functions are used in a function passed to functions like \code{lapply}. -# NOTE: The package is assumed to be installed on every node in the Spark -# cluster. -# -# @param sc SparkContext to use -# @param pkg Package name -# -# @export -# @examples -#\dontrun{ -# library(Matrix) -# -# sc <- sparkR.init() -# # Include the matrix library we will be using -# includePackage(sc, Matrix) -# -# generateSparse <- function(x) { -# sparseMatrix(i=c(1, 2, 3), j=c(1, 2, 3), x=c(1, 2, 3)) -# } -# -# rdd <- lapplyPartition(parallelize(sc, 1:2, 2L), generateSparse) -# collect(rdd) -#} +#' Include this specified package on all workers +#' +#' This function can be used to include a package on all workers before the +#' user's code is executed. This is useful in scenarios where other R package +#' functions are used in a function passed to functions like \code{lapply}. +#' NOTE: The package is assumed to be installed on every node in the Spark +#' cluster. +#' +#' @param sc SparkContext to use +#' @param pkg Package name +#' @noRd +#' @examples +#'\dontrun{ +#' library(Matrix) +#' +#' sc <- sparkR.init() +#' # Include the matrix library we will be using +#' includePackage(sc, Matrix) +#' +#' generateSparse <- function(x) { +#' sparseMatrix(i=c(1, 2, 3), j=c(1, 2, 3), x=c(1, 2, 3)) +#' } +#' +#' rdd <- lapplyPartition(parallelize(sc, 1:2, 2L), generateSparse) +#' collect(rdd) +#'} includePackage <- function(sc, pkg) { pkg <- as.character(substitute(pkg)) if (exists(".packages", .sparkREnv)) { @@ -171,30 +170,30 @@ includePackage <- function(sc, pkg) { .sparkREnv$.packages <- packages } -# @title Broadcast a variable to all workers -# -# @description -# Broadcast a read-only variable to the cluster, returning a \code{Broadcast} -# object for reading it in distributed functions. -# -# @param sc Spark Context to use -# @param object Object to be broadcast -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:2, 2L) -# -# # Large Matrix object that we want to broadcast -# randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) -# randomMatBr <- broadcast(sc, randomMat) -# -# # Use the broadcast variable inside the function -# useBroadcast <- function(x) { -# sum(value(randomMatBr) * x) -# } -# sumRDD <- lapply(rdd, useBroadcast) -#} +#' @title Broadcast a variable to all workers +#' +#' @description +#' Broadcast a read-only variable to the cluster, returning a \code{Broadcast} +#' object for reading it in distributed functions. +#' +#' @param sc Spark Context to use +#' @param object Object to be broadcast +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:2, 2L) +#' +#' # Large Matrix object that we want to broadcast +#' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) +#' randomMatBr <- broadcast(sc, randomMat) +#' +#' # Use the broadcast variable inside the function +#' useBroadcast <- function(x) { +#' sum(value(randomMatBr) * x) +#' } +#' sumRDD <- lapply(rdd, useBroadcast) +#'} broadcast <- function(sc, object) { objName <- as.character(substitute(object)) serializedObj <- serialize(object, connection = NULL) @@ -205,21 +204,21 @@ broadcast <- function(sc, object) { Broadcast(id, object, jBroadcast, objName) } -# @title Set the checkpoint directory -# -# Set the directory under which RDDs are going to be checkpointed. The -# directory must be a HDFS path if running on a cluster. -# -# @param sc Spark Context to use -# @param dirName Directory path -# @export -# @examples -#\dontrun{ -# sc <- sparkR.init() -# setCheckpointDir(sc, "~/checkpoint") -# rdd <- parallelize(sc, 1:2, 2L) -# checkpoint(rdd) -#} +#' @title Set the checkpoint directory +#' +#' Set the directory under which RDDs are going to be checkpointed. The +#' directory must be a HDFS path if running on a cluster. +#' +#' @param sc Spark Context to use +#' @param dirName Directory path +#' @noRd +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setCheckpointDir(sc, "~/checkpoint") +#' rdd <- parallelize(sc, 1:2, 2L) +#' checkpoint(rdd) +#'} setCheckpointDir <- function(sc, dirName) { invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName)))) } diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index d7fd279279137..3d0255a62f155 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -520,6 +520,22 @@ setMethod("isNaN", column(jc) }) +#' kurtosis +#' +#' Aggregate function: returns the kurtosis of the values in a group. +#' +#' @rdname kurtosis +#' @name kurtosis +#' @family agg_funcs +#' @export +#' @examples \dontrun{kurtosis(df$c)} +setMethod("kurtosis", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "kurtosis", x@jc) + column(jc) + }) + #' last #' #' Aggregate function: returns the last value in a group. @@ -861,6 +877,28 @@ setMethod("rtrim", column(jc) }) +#' sd +#' +#' Aggregate function: alias for \link{stddev_samp} +#' +#' @rdname sd +#' @name sd +#' @family agg_funcs +#' @seealso \link{stddev_pop}, \link{stddev_samp} +#' @export +#' @examples +#'\dontrun{ +#'stddev(df$c) +#'select(df, stddev(df$age)) +#'agg(df, sd(df$age)) +#'} +setMethod("sd", + signature(x = "Column"), + function(x, na.rm = FALSE) { + # In R, sample standard deviation is calculated with the sd() function. + stddev_samp(x) + }) + #' second #' #' Extracts the seconds as an integer from a given date/timestamp/string. @@ -958,6 +996,22 @@ setMethod("size", column(jc) }) +#' skewness +#' +#' Aggregate function: returns the skewness of the values in a group. +#' +#' @rdname skewness +#' @name skewness +#' @family agg_funcs +#' @export +#' @examples \dontrun{skewness(df$c)} +setMethod("skewness", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "skewness", x@jc) + column(jc) + }) + #' soundex #' #' Return the soundex code for the specified expression. @@ -974,6 +1028,49 @@ setMethod("soundex", column(jc) }) +#' @rdname sd +#' @name stddev +setMethod("stddev", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "stddev", x@jc) + column(jc) + }) + +#' stddev_pop +#' +#' Aggregate function: returns the population standard deviation of the expression in a group. +#' +#' @rdname stddev_pop +#' @name stddev_pop +#' @family agg_funcs +#' @seealso \link{sd}, \link{stddev_samp} +#' @export +#' @examples \dontrun{stddev_pop(df$c)} +setMethod("stddev_pop", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "stddev_pop", x@jc) + column(jc) + }) + +#' stddev_samp +#' +#' Aggregate function: returns the unbiased sample standard deviation of the expression in a group. +#' +#' @rdname stddev_samp +#' @name stddev_samp +#' @family agg_funcs +#' @seealso \link{stddev_pop}, \link{sd} +#' @export +#' @examples \dontrun{stddev_samp(df$c)} +setMethod("stddev_samp", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "stddev_samp", x@jc) + column(jc) + }) + #' sqrt #' #' Computes the square root of the specified float value. @@ -1168,6 +1265,71 @@ setMethod("upper", column(jc) }) +#' var +#' +#' Aggregate function: alias for \link{var_samp}. +#' +#' @rdname var +#' @name var +#' @family agg_funcs +#' @seealso \link{var_pop}, \link{var_samp} +#' @export +#' @examples +#'\dontrun{ +#'variance(df$c) +#'select(df, var_pop(df$age)) +#'agg(df, var(df$age)) +#'} +setMethod("var", + signature(x = "Column"), + function(x, y = NULL, na.rm = FALSE, use) { + # In R, sample variance is calculated with the var() function. + var_samp(x) + }) + +#' @rdname var +#' @name variance +setMethod("variance", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "variance", x@jc) + column(jc) + }) + +#' var_pop +#' +#' Aggregate function: returns the population variance of the values in a group. +#' +#' @rdname var_pop +#' @name var_pop +#' @family agg_funcs +#' @seealso \link{var}, \link{var_samp} +#' @export +#' @examples \dontrun{var_pop(df$c)} +setMethod("var_pop", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "var_pop", x@jc) + column(jc) + }) + +#' var_samp +#' +#' Aggregate function: returns the unbiased variance of the values in a group. +#' +#' @rdname var_samp +#' @name var_samp +#' @family agg_funcs +#' @seealso \link{var_pop}, \link{var} +#' @export +#' @examples \dontrun{var_samp(df$c)} +setMethod("var_samp", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "var_samp", x@jc) + column(jc) + }) + #' weekofyear #' #' Extracts the week number as an integer from a given date/timestamp/string. @@ -1339,7 +1501,7 @@ setMethod("pmod", signature(y = "Column"), #' @export setMethod("approxCountDistinct", signature(x = "Column"), - function(x, rsd = 0.95) { + function(x, rsd = 0.05) { jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) column(jc) }) @@ -2020,10 +2182,10 @@ setMethod("ifelse", #' #' Window function: returns the cumulative distribution of values within a window partition, #' i.e. the fraction of rows that are below the current row. -#' +#' #' N = total number of rows in the partition #' cumeDist(x) = number of values before (and including) x / N -#' +#' #' This is equivalent to the CUME_DIST function in SQL. #' #' @rdname cumeDist @@ -2039,13 +2201,13 @@ setMethod("cumeDist", }) #' denseRank -#' +#' #' Window function: returns the rank of rows within a window partition, without any gaps. #' The difference between rank and denseRank is that denseRank leaves no gaps in ranking #' sequence when there are ties. That is, if you were ranking a competition using denseRank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. -#' +#' #' This is equivalent to the DENSE_RANK function in SQL. #' #' @rdname denseRank @@ -2065,7 +2227,7 @@ setMethod("denseRank", #' Window function: returns the value that is `offset` rows before the current row, and #' `defaultValue` if there is less than `offset` rows before the current row. For example, #' an `offset` of one will return the previous row at any given point in the window partition. -#' +#' #' This is equivalent to the LAG function in SQL. #' #' @rdname lag @@ -2092,7 +2254,7 @@ setMethod("lag", #' Window function: returns the value that is `offset` rows after the current row, and #' `null` if there is less than `offset` rows after the current row. For example, #' an `offset` of one will return the next row at any given point in the window partition. -#' +#' #' This is equivalent to the LEAD function in SQL. #' #' @rdname lead @@ -2119,7 +2281,7 @@ setMethod("lead", #' Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window #' partition. Fow example, if `n` is 4, the first quarter of the rows will get value 1, the second #' quarter will get 2, the third quarter will get 3, and the last quarter will get 4. -#' +#' #' This is equivalent to the NTILE function in SQL. #' #' @rdname ntile @@ -2137,9 +2299,9 @@ setMethod("ntile", #' percentRank #' #' Window function: returns the relative rank (i.e. percentile) of rows within a window partition. -#' +#' #' This is computed by: -#' +#' #' (rank of row in its partition - 1) / (number of rows in the partition - 1) #' #' This is equivalent to the PERCENT_RANK function in SQL. @@ -2159,12 +2321,12 @@ setMethod("percentRank", #' rank #' #' Window function: returns the rank of rows within a window partition. -#' +#' #' The difference between rank and denseRank is that denseRank leaves no gaps in ranking #' sequence when there are ties. That is, if you were ranking a competition using denseRank #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. -#' +#' #' This is equivalent to the RANK function in SQL. #' #' @rdname rank @@ -2189,7 +2351,7 @@ setMethod("rank", #' rowNumber #' #' Window function: returns a sequential number starting at 1 within a window partition. -#' +#' #' This is equivalent to the ROW_NUMBER function in SQL. #' #' @rdname rowNumber diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 083d37fee28a4..612e639f8ad99 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -88,12 +88,8 @@ setGeneric("flatMap", function(X, FUN) { standardGeneric("flatMap") }) # @export setGeneric("fold", function(x, zeroValue, op) { standardGeneric("fold") }) -# @rdname foreach -# @export setGeneric("foreach", function(x, func) { standardGeneric("foreach") }) -# @rdname foreach -# @export setGeneric("foreachPartition", function(x, func) { standardGeneric("foreachPartition") }) # The jrdd accessor function. @@ -107,27 +103,17 @@ setGeneric("glom", function(x) { standardGeneric("glom") }) # @export setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) -# @rdname lapplyPartition -# @export setGeneric("lapplyPartition", function(X, FUN) { standardGeneric("lapplyPartition") }) -# @rdname lapplyPartitionsWithIndex -# @export setGeneric("lapplyPartitionsWithIndex", function(X, FUN) { standardGeneric("lapplyPartitionsWithIndex") }) -# @rdname lapply -# @export setGeneric("map", function(X, FUN) { standardGeneric("map") }) -# @rdname lapplyPartition -# @export setGeneric("mapPartitions", function(X, FUN) { standardGeneric("mapPartitions") }) -# @rdname lapplyPartitionsWithIndex -# @export setGeneric("mapPartitionsWithIndex", function(X, FUN) { standardGeneric("mapPartitionsWithIndex") }) @@ -561,14 +547,10 @@ setGeneric("summarize", function(x,...) { standardGeneric("summarize") }) #' @rdname summary #' @export -setGeneric("summary", function(x, ...) { standardGeneric("summary") }) +setGeneric("summary", function(object, ...) { standardGeneric("summary") }) -# @rdname tojson -# @export setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) -#' @rdname DataFrame -#' @export setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) #' @rdname unionAll @@ -798,6 +780,10 @@ setGeneric("instr", function(y, x) { standardGeneric("instr") }) #' @export setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) +#' @rdname kurtosis +#' @export +setGeneric("kurtosis", function(x) { standardGeneric("kurtosis") }) + #' @rdname lag #' @export setGeneric("lag", function(x, offset, defaultValue = NULL) { standardGeneric("lag") }) @@ -935,6 +921,10 @@ setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) #' @export setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) +#' @rdname sd +#' @export +setGeneric("sd", function(x, na.rm = FALSE) { standardGeneric("sd") }) + #' @rdname second #' @export setGeneric("second", function(x) { standardGeneric("second") }) @@ -967,10 +957,26 @@ setGeneric("signum", function(x) { standardGeneric("signum") }) #' @export setGeneric("size", function(x) { standardGeneric("size") }) +#' @rdname skewness +#' @export +setGeneric("skewness", function(x) { standardGeneric("skewness") }) + #' @rdname soundex #' @export setGeneric("soundex", function(x) { standardGeneric("soundex") }) +#' @rdname sd +#' @export +setGeneric("stddev", function(x) { standardGeneric("stddev") }) + +#' @rdname stddev_pop +#' @export +setGeneric("stddev_pop", function(x) { standardGeneric("stddev_pop") }) + +#' @rdname stddev_samp +#' @export +setGeneric("stddev_samp", function(x) { standardGeneric("stddev_samp") }) + #' @rdname substring_index #' @export setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) @@ -1019,6 +1025,22 @@ setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timesta #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) +#' @rdname var +#' @export +setGeneric("var", function(x, y = NULL, na.rm = FALSE, use) { standardGeneric("var") }) + +#' @rdname var +#' @export +setGeneric("variance", function(x) { standardGeneric("variance") }) + +#' @rdname var_pop +#' @export +setGeneric("var_pop", function(x) { standardGeneric("var_pop") }) + +#' @rdname var_samp +#' @export +setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) + #' @rdname weekofyear #' @export setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) @@ -1047,3 +1069,7 @@ setGeneric("attach") #' @rdname with #' @export setGeneric("with") + +#' @rdname coltypes +#' @export +setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) \ No newline at end of file diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 4cab1a69f601a..e5f702faee65d 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -79,6 +79,7 @@ setMethod("count", #' @param x a GroupedData #' @return a DataFrame #' @rdname agg +#' @family agg_funcs #' @examples #' \dontrun{ #' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' @@ -117,8 +118,11 @@ setMethod("summarize", agg(x, ...) }) -# sum/mean/avg/min/max -methods <- c("sum", "mean", "avg", "min", "max") +# Aggregate Functions by name +methods <- c("avg", "max", "mean", "min", "sum") + +# These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", "stddev_pop", +# "variance", "var_samp", "var_pop" createMethod <- function(name) { setMethod(name, diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index b0d73dd93a79d..f23e1c7f1fce4 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -89,14 +89,28 @@ setMethod("predict", signature(object = "PipelineModel"), #' model <- glm(y ~ x, trainingData) #' summary(model) #'} -setMethod("summary", signature(x = "PipelineModel"), - function(x, ...) { +setMethod("summary", signature(object = "PipelineModel"), + function(object, ...) { + modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelName", object@model) features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelFeatures", x@model) + "getModelFeatures", object@model) coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelCoefficients", x@model) - coefficients <- as.matrix(unlist(coefficients)) - colnames(coefficients) <- c("Estimate") - rownames(coefficients) <- unlist(features) - return(list(coefficients = coefficients)) + "getModelCoefficients", object@model) + if (modelName == "LinearRegressionModel") { + devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelDevianceResiduals", object@model) + devianceResiduals <- matrix(devianceResiduals, nrow = 1) + colnames(devianceResiduals) <- c("Min", "Max") + rownames(devianceResiduals) <- rep("", times = 1) + coefficients <- matrix(coefficients, ncol = 4) + colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") + rownames(coefficients) <- unlist(features) + return(list(devianceResiduals = devianceResiduals, coefficients = coefficients)) + } else { + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + } }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 199c3fd6ab1b2..991bea4d2022d 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -21,23 +21,24 @@ NULL ############ Actions and Transformations ############ -# Look up elements of a key in an RDD -# -# @description -# \code{lookup} returns a list of values in this RDD for key key. -# -# @param x The RDD to collect -# @param key The key to look up for -# @return a list of values in this RDD for key key -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(c(1, 1), c(2, 2), c(1, 3)) -# rdd <- parallelize(sc, pairs) -# lookup(rdd, 1) # list(1, 3) -#} -# @rdname lookup -# @aliases lookup,RDD-method +#' Look up elements of a key in an RDD +#' +#' @description +#' \code{lookup} returns a list of values in this RDD for key key. +#' +#' @param x The RDD to collect +#' @param key The key to look up for +#' @return a list of values in this RDD for key key +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(c(1, 1), c(2, 2), c(1, 3)) +#' rdd <- parallelize(sc, pairs) +#' lookup(rdd, 1) # list(1, 3) +#'} +#' @rdname lookup +#' @aliases lookup,RDD-method +#' @noRd setMethod("lookup", signature(x = "RDD", key = "ANY"), function(x, key) { @@ -49,21 +50,22 @@ setMethod("lookup", collect(valsRDD) }) -# Count the number of elements for each key, and return the result to the -# master as lists of (key, count) pairs. -# -# Same as countByKey in Spark. -# -# @param x The RDD to count keys. -# @return list of (key, count) pairs, where count is number of each key in rdd. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1))) -# countByKey(rdd) # ("a", 2L), ("b", 1L) -#} -# @rdname countByKey -# @aliases countByKey,RDD-method +#' Count the number of elements for each key, and return the result to the +#' master as lists of (key, count) pairs. +#' +#' Same as countByKey in Spark. +#' +#' @param x The RDD to count keys. +#' @return list of (key, count) pairs, where count is number of each key in rdd. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1))) +#' countByKey(rdd) # ("a", 2L), ("b", 1L) +#'} +#' @rdname countByKey +#' @aliases countByKey,RDD-method +#' @noRd setMethod("countByKey", signature(x = "RDD"), function(x) { @@ -71,17 +73,18 @@ setMethod("countByKey", countByValue(keys) }) -# Return an RDD with the keys of each tuple. -# -# @param x The RDD from which the keys of each tuple is returned. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) -# collect(keys(rdd)) # list(1, 3) -#} -# @rdname keys -# @aliases keys,RDD +#' Return an RDD with the keys of each tuple. +#' +#' @param x The RDD from which the keys of each tuple is returned. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(keys(rdd)) # list(1, 3) +#'} +#' @rdname keys +#' @aliases keys,RDD +#' @noRd setMethod("keys", signature(x = "RDD"), function(x) { @@ -91,17 +94,18 @@ setMethod("keys", lapply(x, func) }) -# Return an RDD with the values of each tuple. -# -# @param x The RDD from which the values of each tuple is returned. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) -# collect(values(rdd)) # list(2, 4) -#} -# @rdname values -# @aliases values,RDD +#' Return an RDD with the values of each tuple. +#' +#' @param x The RDD from which the values of each tuple is returned. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) +#' collect(values(rdd)) # list(2, 4) +#'} +#' @rdname values +#' @aliases values,RDD +#' @noRd setMethod("values", signature(x = "RDD"), function(x) { @@ -111,23 +115,24 @@ setMethod("values", lapply(x, func) }) -# Applies a function to all values of the elements, without modifying the keys. -# -# The same as `mapValues()' in Spark. -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on the value of each element. -# @return a new RDD created by the transformation. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:10) -# makePairs <- lapply(rdd, function(x) { list(x, x) }) -# collect(mapValues(makePairs, function(x) { x * 2) }) -# Output: list(list(1,2), list(2,4), list(3,6), ...) -#} -# @rdname mapValues -# @aliases mapValues,RDD,function-method +#' Applies a function to all values of the elements, without modifying the keys. +#' +#' The same as `mapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:10) +#' makePairs <- lapply(rdd, function(x) { list(x, x) }) +#' collect(mapValues(makePairs, function(x) { x * 2) }) +#' Output: list(list(1,2), list(2,4), list(3,6), ...) +#'} +#' @rdname mapValues +#' @aliases mapValues,RDD,function-method +#' @noRd setMethod("mapValues", signature(X = "RDD", FUN = "function"), function(X, FUN) { @@ -137,23 +142,24 @@ setMethod("mapValues", lapply(X, func) }) -# Pass each value in the key-value pair RDD through a flatMap function without -# changing the keys; this also retains the original RDD's partitioning. -# -# The same as 'flatMapValues()' in Spark. -# -# @param X The RDD to apply the transformation. -# @param FUN the transformation to apply on the value of each element. -# @return a new RDD created by the transformation. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) -# collect(flatMapValues(rdd, function(x) { x })) -# Output: list(list(1,1), list(1,2), list(2,3), list(2,4)) -#} -# @rdname flatMapValues -# @aliases flatMapValues,RDD,function-method +#' Pass each value in the key-value pair RDD through a flatMap function without +#' changing the keys; this also retains the original RDD's partitioning. +#' +#' The same as 'flatMapValues()' in Spark. +#' +#' @param X The RDD to apply the transformation. +#' @param FUN the transformation to apply on the value of each element. +#' @return a new RDD created by the transformation. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) +#' collect(flatMapValues(rdd, function(x) { x })) +#' Output: list(list(1,1), list(1,2), list(2,3), list(2,4)) +#'} +#' @rdname flatMapValues +#' @aliases flatMapValues,RDD,function-method +#' @noRd setMethod("flatMapValues", signature(X = "RDD", FUN = "function"), function(X, FUN) { @@ -165,38 +171,34 @@ setMethod("flatMapValues", ############ Shuffle Functions ############ -# Partition an RDD by key -# -# This function operates on RDDs where every element is of the form list(K, V) or c(K, V). -# For each element of this RDD, the partitioner is used to compute a hash -# function and the RDD is partitioned using this hash value. -# -# @param x The RDD to partition. Should be an RDD where each element is -# list(K, V) or c(K, V). -# @param numPartitions Number of partitions to create. -# @param ... Other optional arguments to partitionBy. -# -# @param partitionFunc The partition function to use. Uses a default hashCode -# function if not provided -# @return An RDD partitioned using the specified partitioner. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) -# rdd <- parallelize(sc, pairs) -# parts <- partitionBy(rdd, 2L) -# collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4) -#} -# @rdname partitionBy -# @aliases partitionBy,RDD,integer-method +#' Partition an RDD by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' For each element of this RDD, the partitioner is used to compute a hash +#' function and the RDD is partitioned using this hash value. +#' +#' @param x The RDD to partition. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @param ... Other optional arguments to partitionBy. +#' +#' @param partitionFunc The partition function to use. Uses a default hashCode +#' function if not provided +#' @return An RDD partitioned using the specified partitioner. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- partitionBy(rdd, 2L) +#' collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4) +#'} +#' @rdname partitionBy +#' @aliases partitionBy,RDD,integer-method +#' @noRd setMethod("partitionBy", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions, partitionFunc = hashCode) { - - #if (missing(partitionFunc)) { - # partitionFunc <- hashCode - #} - partitionFunc <- cleanClosure(partitionFunc) serializedHashFuncBytes <- serialize(partitionFunc, connection = NULL) @@ -233,27 +235,28 @@ setMethod("partitionBy", RDD(r, serializedMode = "byte") }) -# Group values by key -# -# This function operates on RDDs where every element is of the form list(K, V) or c(K, V). -# and group values for each key in the RDD into a single sequence. -# -# @param x The RDD to group. Should be an RDD where each element is -# list(K, V) or c(K, V). -# @param numPartitions Number of partitions to create. -# @return An RDD where each element is list(K, list(V)) -# @seealso reduceByKey -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) -# rdd <- parallelize(sc, pairs) -# parts <- groupByKey(rdd, 2L) -# grouped <- collect(parts) -# grouped[[1]] # Should be a list(1, list(2, 4)) -#} -# @rdname groupByKey -# @aliases groupByKey,RDD,integer-method +#' Group values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and group values for each key in the RDD into a single sequence. +#' +#' @param x The RDD to group. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, list(V)) +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- groupByKey(rdd, 2L) +#' grouped <- collect(parts) +#' grouped[[1]] # Should be a list(1, list(2, 4)) +#'} +#' @rdname groupByKey +#' @aliases groupByKey,RDD,integer-method +#' @noRd setMethod("groupByKey", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions) { @@ -291,28 +294,29 @@ setMethod("groupByKey", lapplyPartition(shuffled, groupVals) }) -# Merge values by key -# -# This function operates on RDDs where every element is of the form list(K, V) or c(K, V). -# and merges the values for each key using an associative reduce function. -# -# @param x The RDD to reduce by key. Should be an RDD where each element is -# list(K, V) or c(K, V). -# @param combineFunc The associative reduce function to use. -# @param numPartitions Number of partitions to create. -# @return An RDD where each element is list(K, V') where V' is the merged -# value -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) -# rdd <- parallelize(sc, pairs) -# parts <- reduceByKey(rdd, "+", 2L) -# reduced <- collect(parts) -# reduced[[1]] # Should be a list(1, 6) -#} -# @rdname reduceByKey -# @aliases reduceByKey,RDD,integer-method +#' Merge values by key +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative reduce function. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative reduce function to use. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, V') where V' is the merged +#' value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- reduceByKey(rdd, "+", 2L) +#' reduced <- collect(parts) +#' reduced[[1]] # Should be a list(1, 6) +#'} +#' @rdname reduceByKey +#' @aliases reduceByKey,RDD,integer-method +#' @noRd setMethod("reduceByKey", signature(x = "RDD", combineFunc = "ANY", numPartitions = "numeric"), function(x, combineFunc, numPartitions) { @@ -332,27 +336,28 @@ setMethod("reduceByKey", lapplyPartition(shuffled, reduceVals) }) -# Merge values by key locally -# -# This function operates on RDDs where every element is of the form list(K, V) or c(K, V). -# and merges the values for each key using an associative reduce function, but return the -# results immediately to the driver as an R list. -# -# @param x The RDD to reduce by key. Should be an RDD where each element is -# list(K, V) or c(K, V). -# @param combineFunc The associative reduce function to use. -# @return A list of elements of type list(K, V') where V' is the merged value for each key -# @seealso reduceByKey -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) -# rdd <- parallelize(sc, pairs) -# reduced <- reduceByKeyLocally(rdd, "+") -# reduced # list(list(1, 6), list(1.1, 3)) -#} -# @rdname reduceByKeyLocally -# @aliases reduceByKeyLocally,RDD,integer-method +#' Merge values by key locally +#' +#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V). +#' and merges the values for each key using an associative reduce function, but return the +#' results immediately to the driver as an R list. +#' +#' @param x The RDD to reduce by key. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param combineFunc The associative reduce function to use. +#' @return A list of elements of type list(K, V') where V' is the merged value for each key +#' @seealso reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' reduced <- reduceByKeyLocally(rdd, "+") +#' reduced # list(list(1, 6), list(1.1, 3)) +#'} +#' @rdname reduceByKeyLocally +#' @aliases reduceByKeyLocally,RDD,integer-method +#' @noRd setMethod("reduceByKeyLocally", signature(x = "RDD", combineFunc = "ANY"), function(x, combineFunc) { @@ -384,41 +389,40 @@ setMethod("reduceByKeyLocally", convertEnvsToList(merged[[1]], merged[[2]]) }) -# Combine values by key -# -# Generic function to combine the elements for each key using a custom set of -# aggregation functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], -# for a "combined type" C. Note that V and C can be different -- for example, one -# might group an RDD of type (Int, Int) into an RDD of type (Int, Seq[Int]). - -# Users provide three functions: -# \itemize{ -# \item createCombiner, which turns a V into a C (e.g., creates a one-element list) -# \item mergeValue, to merge a V into a C (e.g., adds it to the end of a list) - -# \item mergeCombiners, to combine two C's into a single one (e.g., concatentates -# two lists). -# } -# -# @param x The RDD to combine. Should be an RDD where each element is -# list(K, V) or c(K, V). -# @param createCombiner Create a combiner (C) given a value (V) -# @param mergeValue Merge the given value (V) with an existing combiner (C) -# @param mergeCombiners Merge two combiners and return a new combiner -# @param numPartitions Number of partitions to create. -# @return An RDD where each element is list(K, C) where C is the combined type -# -# @seealso groupByKey, reduceByKey -# @examples -#\dontrun{ -# sc <- sparkR.init() -# pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) -# rdd <- parallelize(sc, pairs) -# parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L) -# combined <- collect(parts) -# combined[[1]] # Should be a list(1, 6) -#} -# @rdname combineByKey -# @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method +#' Combine values by key +#' +#' Generic function to combine the elements for each key using a custom set of +#' aggregation functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], +#' for a "combined type" C. Note that V and C can be different -- for example, one +#' might group an RDD of type (Int, Int) into an RDD of type (Int, Seq[Int]). +#' Users provide three functions: +#' \itemize{ +#' \item createCombiner, which turns a V into a C (e.g., creates a one-element list) +#' \item mergeValue, to merge a V into a C (e.g., adds it to the end of a list) - +#' \item mergeCombiners, to combine two C's into a single one (e.g., concatentates +#' two lists). +#' } +#' +#' @param x The RDD to combine. Should be an RDD where each element is +#' list(K, V) or c(K, V). +#' @param createCombiner Create a combiner (C) given a value (V) +#' @param mergeValue Merge the given value (V) with an existing combiner (C) +#' @param mergeCombiners Merge two combiners and return a new combiner +#' @param numPartitions Number of partitions to create. +#' @return An RDD where each element is list(K, C) where C is the combined type +#' @seealso groupByKey, reduceByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) +#' rdd <- parallelize(sc, pairs) +#' parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L) +#' combined <- collect(parts) +#' combined[[1]] # Should be a list(1, 6) +#'} +#' @rdname combineByKey +#' @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method +#' @noRd setMethod("combineByKey", signature(x = "RDD", createCombiner = "ANY", mergeValue = "ANY", mergeCombiners = "ANY", numPartitions = "numeric"), @@ -450,36 +454,37 @@ setMethod("combineByKey", lapplyPartition(shuffled, mergeAfterShuffle) }) -# Aggregate a pair RDD by each key. -# -# Aggregate the values of each key in an RDD, using given combine functions -# and a neutral "zero value". This function can return a different result type, -# U, than the type of the values in this RDD, V. Thus, we need one operation -# for merging a V into a U and one operation for merging two U's, The former -# operation is used for merging values within a partition, and the latter is -# used for merging values between partitions. To avoid memory allocation, both -# of these functions are allowed to modify and return their first argument -# instead of creating a new U. -# -# @param x An RDD. -# @param zeroValue A neutral "zero value". -# @param seqOp A function to aggregate the values of each key. It may return -# a different result type from the type of the values. -# @param combOp A function to aggregate results of seqOp. -# @return An RDD containing the aggregation result. -# @seealso foldByKey, combineByKey -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) -# zeroValue <- list(0, 0) -# seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } -# combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } -# aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) -# # list(list(1, list(3, 2)), list(2, list(7, 2))) -#} -# @rdname aggregateByKey -# @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method +#' Aggregate a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using given combine functions +#' and a neutral "zero value". This function can return a different result type, +#' U, than the type of the values in this RDD, V. Thus, we need one operation +#' for merging a V into a U and one operation for merging two U's, The former +#' operation is used for merging values within a partition, and the latter is +#' used for merging values between partitions. To avoid memory allocation, both +#' of these functions are allowed to modify and return their first argument +#' instead of creating a new U. +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param seqOp A function to aggregate the values of each key. It may return +#' a different result type from the type of the values. +#' @param combOp A function to aggregate results of seqOp. +#' @return An RDD containing the aggregation result. +#' @seealso foldByKey, combineByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' zeroValue <- list(0, 0) +#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } +#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } +#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) +#' # list(list(1, list(3, 2)), list(2, list(7, 2))) +#'} +#' @rdname aggregateByKey +#' @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method +#' @noRd setMethod("aggregateByKey", signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY", numPartitions = "numeric"), @@ -491,26 +496,27 @@ setMethod("aggregateByKey", combineByKey(x, createCombiner, seqOp, combOp, numPartitions) }) -# Fold a pair RDD by each key. -# -# Aggregate the values of each key in an RDD, using an associative function "func" -# and a neutral "zero value" which may be added to the result an arbitrary -# number of times, and must not change the result (e.g., 0 for addition, or -# 1 for multiplication.). -# -# @param x An RDD. -# @param zeroValue A neutral "zero value". -# @param func An associative function for folding values of each key. -# @return An RDD containing the aggregation result. -# @seealso aggregateByKey, combineByKey -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) -# foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7)) -#} -# @rdname foldByKey -# @aliases foldByKey,RDD,ANY,ANY,integer-method +#' Fold a pair RDD by each key. +#' +#' Aggregate the values of each key in an RDD, using an associative function "func" +#' and a neutral "zero value" which may be added to the result an arbitrary +#' number of times, and must not change the result (e.g., 0 for addition, or +#' 1 for multiplication.). +#' +#' @param x An RDD. +#' @param zeroValue A neutral "zero value". +#' @param func An associative function for folding values of each key. +#' @return An RDD containing the aggregation result. +#' @seealso aggregateByKey, combineByKey +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) +#' foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7)) +#'} +#' @rdname foldByKey +#' @aliases foldByKey,RDD,ANY,ANY,integer-method +#' @noRd setMethod("foldByKey", signature(x = "RDD", zeroValue = "ANY", func = "ANY", numPartitions = "numeric"), @@ -520,28 +526,29 @@ setMethod("foldByKey", ############ Binary Functions ############# -# Join two RDDs -# -# @description -# \code{join} This function joins two RDDs where every element is of the form list(K, V). -# The key types of the two RDDs should be the same. -# -# @param x An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param y An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param numPartitions Number of partitions to create. -# @return a new RDD containing all pairs of elements with matching keys in -# two input RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) -# rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) -# join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) -#} -# @rdname join-methods -# @aliases join,RDD,RDD-method +#' Join two RDDs +#' +#' @description +#' \code{join} This function joins two RDDs where every element is of the form list(K, V). +#' The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with matching keys in +#' two input RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) +#'} +#' @rdname join-methods +#' @aliases join,RDD,RDD-method +#' @noRd setMethod("join", signature(x = "RDD", y = "RDD"), function(x, y, numPartitions) { @@ -556,30 +563,31 @@ setMethod("join", doJoin) }) -# Left outer join two RDDs -# -# @description -# \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of -# the form list(K, V). The key types of the two RDDs should be the same. -# -# @param x An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param y An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param numPartitions Number of partitions to create. -# @return For each element (k, v) in x, the resulting RDD will either contain -# all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) -# if no elements in rdd2 have key k. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) -# rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) -# leftOuterJoin(rdd1, rdd2, 2L) -# # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) -#} -# @rdname join-methods -# @aliases leftOuterJoin,RDD,RDD-method +#' Left outer join two RDDs +#' +#' @description +#' \code{leftouterjoin} This function left-outer-joins two RDDs where every element is of +#' the form list(K, V). The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) +#' if no elements in rdd2 have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' leftOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) +#'} +#' @rdname join-methods +#' @aliases leftOuterJoin,RDD,RDD-method +#' @noRd setMethod("leftOuterJoin", signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { @@ -593,30 +601,31 @@ setMethod("leftOuterJoin", joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) -# Right outer join two RDDs -# -# @description -# \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of -# the form list(K, V). The key types of the two RDDs should be the same. -# -# @param x An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param y An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param numPartitions Number of partitions to create. -# @return For each element (k, w) in y, the resulting RDD will either contain -# all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w)) -# if no elements in x have key k. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) -# rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) -# rightOuterJoin(rdd1, rdd2, 2L) -# # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) -#} -# @rdname join-methods -# @aliases rightOuterJoin,RDD,RDD-method +#' Right outer join two RDDs +#' +#' @description +#' \code{rightouterjoin} This function right-outer-joins two RDDs where every element is of +#' the form list(K, V). The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, w) in y, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w)) +#' if no elements in x have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rightOuterJoin(rdd1, rdd2, 2L) +#' # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) +#'} +#' @rdname join-methods +#' @aliases rightOuterJoin,RDD,RDD-method +#' @noRd setMethod("rightOuterJoin", signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { @@ -630,33 +639,34 @@ setMethod("rightOuterJoin", joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) -# Full outer join two RDDs -# -# @description -# \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of -# the form list(K, V). The key types of the two RDDs should be the same. -# -# @param x An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param y An RDD to be joined. Should be an RDD where each element is -# list(K, V). -# @param numPartitions Number of partitions to create. -# @return For each element (k, v) in x and (k, w) in y, the resulting RDD -# will contain all pairs (k, (v, w)) for both (k, v) in x and -# (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements -# in x/y have key k. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) -# rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) -# fullOuterJoin(rdd1, rdd2, 2L) # list(list(1, list(2, 1)), -# # list(1, list(3, 1)), -# # list(2, list(NULL, 4))) -# # list(3, list(3, NULL)), -#} -# @rdname join-methods -# @aliases fullOuterJoin,RDD,RDD-method +#' Full outer join two RDDs +#' +#' @description +#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of +#' the form list(K, V). The key types of the two RDDs should be the same. +#' +#' @param x An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param y An RDD to be joined. Should be an RDD where each element is +#' list(K, V). +#' @param numPartitions Number of partitions to create. +#' @return For each element (k, v) in x and (k, w) in y, the resulting RDD +#' will contain all pairs (k, (v, w)) for both (k, v) in x and +#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements +#' in x/y have key k. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) +#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' fullOuterJoin(rdd1, rdd2, 2L) # list(list(1, list(2, 1)), +#' # list(1, list(3, 1)), +#' # list(2, list(NULL, 4))) +#' # list(3, list(3, NULL)), +#'} +#' @rdname join-methods +#' @aliases fullOuterJoin,RDD,RDD-method +#' @noRd setMethod("fullOuterJoin", signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { @@ -670,23 +680,24 @@ setMethod("fullOuterJoin", joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) -# For each key k in several RDDs, return a resulting RDD that -# whose values are a list of values for the key in all RDDs. -# -# @param ... Several RDDs. -# @param numPartitions Number of partitions to create. -# @return a new RDD containing all pairs of elements with values in a list -# in all RDDs. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) -# rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) -# cogroup(rdd1, rdd2, numPartitions = 2L) -# # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) -#} -# @rdname cogroup -# @aliases cogroup,RDD-method +#' For each key k in several RDDs, return a resulting RDD that +#' whose values are a list of values for the key in all RDDs. +#' +#' @param ... Several RDDs. +#' @param numPartitions Number of partitions to create. +#' @return a new RDD containing all pairs of elements with values in a list +#' in all RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) +#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) +#' cogroup(rdd1, rdd2, numPartitions = 2L) +#' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) +#'} +#' @rdname cogroup +#' @aliases cogroup,RDD-method +#' @noRd setMethod("cogroup", "RDD", function(..., numPartitions) { @@ -722,20 +733,21 @@ setMethod("cogroup", group.func) }) -# Sort a (k, v) pair RDD by k. -# -# @param x A (k, v) pair RDD to be sorted. -# @param ascending A flag to indicate whether the sorting is ascending or descending. -# @param numPartitions Number of partitions to create. -# @return An RDD where all (k, v) pair elements are sorted. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) -# collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) -#} -# @rdname sortByKey -# @aliases sortByKey,RDD,RDD-method +#' Sort a (k, v) pair RDD by k. +#' +#' @param x A (k, v) pair RDD to be sorted. +#' @param ascending A flag to indicate whether the sorting is ascending or descending. +#' @param numPartitions Number of partitions to create. +#' @return An RDD where all (k, v) pair elements are sorted. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) +#' collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) +#'} +#' @rdname sortByKey +#' @aliases sortByKey,RDD,RDD-method +#' @noRd setMethod("sortByKey", signature(x = "RDD"), function(x, ascending = TRUE, numPartitions = SparkR:::numPartitions(x)) { @@ -784,25 +796,26 @@ setMethod("sortByKey", lapplyPartition(newRDD, partitionFunc) }) -# Subtract a pair RDD with another pair RDD. -# -# Return an RDD with the pairs from x whose keys are not in other. -# -# @param x An RDD. -# @param other An RDD. -# @param numPartitions Number of the partitions in the result RDD. -# @return An RDD with the pairs from x whose keys are not in other. -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), -# list("b", 5), list("a", 2))) -# rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) -# collect(subtractByKey(rdd1, rdd2)) -# # list(list("b", 4), list("b", 5)) -#} -# @rdname subtractByKey -# @aliases subtractByKey,RDD +#' Subtract a pair RDD with another pair RDD. +#' +#' Return an RDD with the pairs from x whose keys are not in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions Number of the partitions in the result RDD. +#' @return An RDD with the pairs from x whose keys are not in other. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), +#' list("b", 5), list("a", 2))) +#' rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) +#' collect(subtractByKey(rdd1, rdd2)) +#' # list(list("b", 4), list("b", 5)) +#'} +#' @rdname subtractByKey +#' @aliases subtractByKey,RDD +#' @noRd setMethod("subtractByKey", signature(x = "RDD", other = "RDD"), function(x, other, numPartitions = SparkR:::numPartitions(x)) { @@ -818,41 +831,42 @@ setMethod("subtractByKey", function (v) { v[[1]] }) }) -# Return a subset of this RDD sampled by key. -# -# @description -# \code{sampleByKey} Create a sample of this RDD using variable sampling rates -# for different keys as specified by fractions, a key to sampling rate map. -# -# @param x The RDD to sample elements by key, where each element is -# list(K, V) or c(K, V). -# @param withReplacement Sampling with replacement or not -# @param fraction The (rough) sample target fraction -# @param seed Randomness seed value -# @examples -#\dontrun{ -# sc <- sparkR.init() -# rdd <- parallelize(sc, 1:3000) -# pairs <- lapply(rdd, function(x) { if (x %% 3 == 0) list("a", x) -# else { if (x %% 3 == 1) list("b", x) else list("c", x) }}) -# fractions <- list(a = 0.2, b = 0.1, c = 0.3) -# sample <- sampleByKey(pairs, FALSE, fractions, 1618L) -# 100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")) # TRUE -# 50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")) # TRUE -# 200 < length(lookup(sample, "c")) && 400 > length(lookup(sample, "c")) # TRUE -# lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0 # TRUE -# lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000 # TRUE -# lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0 # TRUE -# lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000 # TRUE -# lookup(sample, "c")[which.min(lookup(sample, "c"))] >= 0 # TRUE -# lookup(sample, "c")[which.max(lookup(sample, "c"))] <= 2000 # TRUE -# fractions <- list(a = 0.2, b = 0.1, c = 0.3, d = 0.4) -# sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # Key "d" will be ignored -# fractions <- list(a = 0.2, b = 0.1) -# sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # KeyError: "c" -#} -# @rdname sampleByKey -# @aliases sampleByKey,RDD-method +#' Return a subset of this RDD sampled by key. +#' +#' @description +#' \code{sampleByKey} Create a sample of this RDD using variable sampling rates +#' for different keys as specified by fractions, a key to sampling rate map. +#' +#' @param x The RDD to sample elements by key, where each element is +#' list(K, V) or c(K, V). +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3000) +#' pairs <- lapply(rdd, function(x) { if (x %% 3 == 0) list("a", x) +#' else { if (x %% 3 == 1) list("b", x) else list("c", x) }}) +#' fractions <- list(a = 0.2, b = 0.1, c = 0.3) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) +#' 100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")) # TRUE +#' 50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")) # TRUE +#' 200 < length(lookup(sample, "c")) && 400 > length(lookup(sample, "c")) # TRUE +#' lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0 # TRUE +#' lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000 # TRUE +#' lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0 # TRUE +#' lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000 # TRUE +#' lookup(sample, "c")[which.min(lookup(sample, "c"))] >= 0 # TRUE +#' lookup(sample, "c")[which.max(lookup(sample, "c"))] <= 2000 # TRUE +#' fractions <- list(a = 0.2, b = 0.1, c = 0.3, d = 0.4) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # Key "d" will be ignored +#' fractions <- list(a = 0.2, b = 0.1) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # KeyError: "c" +#'} +#' @rdname sampleByKey +#' @aliases sampleByKey,RDD-method +#' @noRd setMethod("sampleByKey", signature(x = "RDD", withReplacement = "logical", fractions = "vector", seed = "integer"), diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 6f0e9a94e9bfa..c6ddb562270b7 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -115,20 +115,7 @@ structField.jobj <- function(x) { } checkType <- function(type) { - primtiveTypes <- c("byte", - "integer", - "float", - "double", - "numeric", - "character", - "string", - "binary", - "raw", - "logical", - "boolean", - "timestamp", - "date") - if (type %in% primtiveTypes) { + if (!is.null(PRIMITIVE_TYPES[[type]])) { return() } else { # Check complex types diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 004d08e74e1cd..7ff3fa628b9ca 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -34,7 +34,6 @@ connExists <- function(env) { sparkR.stop <- function() { env <- .sparkREnv if (exists(".sparkRCon", envir = env)) { - # cat("Stopping SparkR\n") if (exists(".sparkRjsc", envir = env)) { sc <- get(".sparkRjsc", envir = env) callJMethod(sc, "stop") @@ -49,6 +48,12 @@ sparkR.stop <- function() { } } + # Remove the R package lib path from .libPaths() + if (exists(".libPath", envir = env)) { + libPath <- get(".libPath", envir = env) + .libPaths(.libPaths()[.libPaths() != libPath]) + } + if (exists(".backendLaunched", envir = env)) { callJStatic("SparkRHandler", "stopBackend") } @@ -78,7 +83,7 @@ sparkR.stop <- function() { #' Initialize a new Spark Context. #' #' This function initializes a new SparkContext. For details on how to initialize -#' and use SparkR, refer to SparkR programming guide at +#' and use SparkR, refer to SparkR programming guide at #' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparkcontext-sqlcontext}. #' #' @param master The Spark master URL. @@ -156,14 +161,20 @@ sparkR.init <- function( f <- file(path, open="rb") backendPort <- readInt(f) monitorPort <- readInt(f) + rLibPath <- readString(f) close(f) file.remove(path) if (length(backendPort) == 0 || backendPort == 0 || - length(monitorPort) == 0 || monitorPort == 0) { + length(monitorPort) == 0 || monitorPort == 0 || + length(rLibPath) != 1) { stop("JVM failed to launch") } assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv) assign(".backendLaunched", 1, envir = .sparkREnv) + if (rLibPath != "") { + assign(".libPath", rLibPath, envir = .sparkREnv) + .libPaths(c(rLibPath, .libPaths())) + } } .sparkREnv$backendPort <- backendPort diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R new file mode 100644 index 0000000000000..1828c23ab0f6d --- /dev/null +++ b/R/pkg/R/types.R @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# types.R. This file handles the data type mapping between Spark and R + +# The primitive data types, where names(PRIMITIVE_TYPES) are Scala types whereas +# values are equivalent R types. This is stored in an environment to allow for +# more efficient look up (environments use hashmaps). +PRIMITIVE_TYPES <- as.environment(list( + "byte"="integer", + "tinyint"="integer", + "smallint"="integer", + "integer"="integer", + "bigint"="numeric", + "float"="numeric", + "double"="numeric", + "decimal"="numeric", + "string"="character", + "binary"="raw", + "boolean"="logical", + "timestamp"="POSIXct", + "date"="Date")) + +# The complex data types. These do not have any direct mapping to R's types. +COMPLEX_TYPES <- list( + "map"=NA, + "array"=NA, + "struct"=NA) + +# The full list of data types. +DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES)) diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R index 2a8a8213d0849..c55fe9ba7af7a 100644 --- a/R/pkg/inst/profile/general.R +++ b/R/pkg/inst/profile/general.R @@ -17,6 +17,7 @@ .First <- function() { packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR") - .libPaths(c(packageDir, .libPaths())) + dirs <- strsplit(packageDir, ",")[[1]] + .libPaths(c(dirs, .libPaths())) Sys.setenv(NOAWT=1) } diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 7189f1a260934..90a3761e41f82 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -38,7 +38,7 @@ if (nchar(sparkVer) == 0) { cat("\n") } else { - cat(" version ", sparkVer, "\n") + cat(" version ", sparkVer, "\n") } cat(" /_/", "\n") cat("\n") diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 4761e285a2479..d497ad8c9daa3 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -71,12 +71,18 @@ test_that("feature interaction vs native glm", { test_that("summary coefficients match with native glm", { training <- createDataFrame(sqlContext, iris) - stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs")) - coefs <- as.vector(stats$coefficients) - rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) - expect_true(all(abs(rCoefs - coefs) < 1e-6)) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal")) + coefs <- unlist(stats$coefficients) + devianceResiduals <- unlist(stats$devianceResiduals) + + rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) + rCoefs <- unlist(rStats$coefficients) + rDevianceResiduals <- c(-0.95096, 0.72918) + + expect_true(all(abs(rCoefs - coefs) < 1e-5)) + expect_true(all(abs(rDevianceResiduals - devianceResiduals) < 1e-5)) expect_true(all( - as.character(stats$features) == + rownames(stats$coefficients) == c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) }) @@ -85,7 +91,7 @@ test_that("summary coefficients match with native glm of family 'binomial'", { training <- filter(df, df$Species != "setosa") stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial")) - coefs <- as.vector(stats$coefficients) + coefs <- as.vector(stats$coefficients[,1]) rTraining <- iris[iris$Species %in% c("versicolor","virginica"),] rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, @@ -93,6 +99,12 @@ test_that("summary coefficients match with native glm of family 'binomial'", { expect_true(all(abs(rCoefs - coefs) < 1e-4)) expect_true(all( - as.character(stats$features) == + rownames(stats$coefficients) == c("(Intercept)", "Sepal_Length", "Sepal_Width"))) }) + +test_that("summary works on base GLM models", { + baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) + baseSummary <- summary(baseModel) + expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) +}) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 92cff1fba7193..8ff06276599e2 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -242,6 +242,14 @@ test_that("create DataFrame from list or data.frame", { expect_equal(count(df), 3) ldf2 <- collect(df) expect_equal(ldf$a, ldf2$a) + + irisdf <- createDataFrame(sqlContext, iris) + iris_collected <- collect(irisdf) + expect_equivalent(iris_collected[,-5], iris[,-5]) + expect_equal(iris_collected$Species, as.character(iris$Species)) + + mtcarsdf <- createDataFrame(sqlContext, mtcars) + expect_equivalent(collect(mtcarsdf), mtcars) }) test_that("create DataFrame with different data types", { @@ -283,6 +291,14 @@ test_that("create DataFrame with complex types", { expect_equal(s$b, 3L) }) +test_that("create DataFrame from a data.frame with complex types", { + ldf <- data.frame(row.names=1:2) + ldf$a_list <- list(list(1, 2), list(3, 4)) + sdf <- createDataFrame(sqlContext, ldf) + + expect_equivalent(ldf, collect(sdf)) +}) + # For test map type and struct type in DataFrame mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", @@ -647,11 +663,11 @@ test_that("sample on a DataFrame", { sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) expect_is(sampled, "DataFrame") - sampled2 <- sample(df, FALSE, 0.1) + sampled2 <- sample(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled2) < 3) # Also test sample_frac - sampled3 <- sample_frac(df, FALSE, 0.1) + sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) }) @@ -826,12 +842,13 @@ test_that("column functions", { c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c) c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) - c9 <- signum(c) + sin(c) + sinh(c) + size(c) + soundex(c) + sqrt(c) + sum(c) + c9 <- signum(c) + sin(c) + sinh(c) + size(c) + stddev(c) + soundex(c) + sqrt(c) + sum(c) c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c) c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) - c12 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) - c13 <- cumeDist() + ntile(1) - c14 <- denseRank() + percentRank() + rank() + rowNumber() + c12 <- variance(c) + c13 <- lead("col", 1) + lead(c, 1) + lag("col", 1) + lag(c, 1) + c14 <- cumeDist() + ntile(1) + c15 <- denseRank() + percentRank() + rank() + rowNumber() # Test if base::rank() is exposed expect_equal(class(rank())[[1]], "Column") @@ -849,6 +866,12 @@ test_that("column functions", { expect_equal(collect(df3)[[2, 1]], FALSE) expect_equal(collect(df3)[[3, 1]], TRUE) + expect_equal(collect(select(df, sum(df$age)))[1, 1], 49) + + expect_true(abs(collect(select(df, stddev(df$age)))[1, 1] - 7.778175) < 1e-6) + + expect_equal(collect(select(df, var_pop(df$age)))[1, 1], 30.25) + df4 <- createDataFrame(sqlContext, list(list(a = "010101"))) expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15") }) @@ -976,7 +999,7 @@ test_that("when(), otherwise() and ifelse() on a DataFrame", { expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, 0, 1)))[, 1], c(1, 0)) }) -test_that("group by", { +test_that("group by, agg functions", { df <- jsonFile(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum") expect_equal(1, count(df1)) @@ -997,20 +1020,64 @@ test_that("group by", { expect_is(df_summarized, "DataFrame") expect_equal(3, count(df_summarized)) - df3 <- agg(gd, age = "sum") - expect_is(df3, "DataFrame") - expect_equal(3, count(df3)) - - df3 <- agg(gd, age = sum(df$age)) + df3 <- agg(gd, age = "stddev") expect_is(df3, "DataFrame") - expect_equal(3, count(df3)) - expect_equal(columns(df3), c("name", "age")) + df3_local <- collect(df3) + expect_true(is.nan(df3_local[df3_local$name == "Andy",][1, 2])) - df4 <- sum(gd, "age") + df4 <- agg(gd, sumAge = sum(df$age)) expect_is(df4, "DataFrame") expect_equal(3, count(df4)) - expect_equal(3, count(mean(gd, "age"))) - expect_equal(3, count(max(gd, "age"))) + expect_equal(columns(df4), c("name", "sumAge")) + + df5 <- sum(gd, "age") + expect_is(df5, "DataFrame") + expect_equal(3, count(df5)) + + expect_equal(3, count(mean(gd))) + expect_equal(3, count(max(gd))) + expect_equal(30, collect(max(gd))[1, 2]) + expect_equal(1, collect(count(gd))[1, 2]) + + mockLines2 <- c("{\"name\":\"ID1\", \"value\": \"10\"}", + "{\"name\":\"ID1\", \"value\": \"10\"}", + "{\"name\":\"ID1\", \"value\": \"22\"}", + "{\"name\":\"ID2\", \"value\": \"-3\"}") + jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(mockLines2, jsonPath2) + gd2 <- groupBy(jsonFile(sqlContext, jsonPath2), "name") + df6 <- agg(gd2, value = "sum") + df6_local <- collect(df6) + expect_equal(42, df6_local[df6_local$name == "ID1",][1, 2]) + expect_equal(-3, df6_local[df6_local$name == "ID2",][1, 2]) + + df7 <- agg(gd2, value = "stddev") + df7_local <- collect(df7) + expect_true(abs(df7_local[df7_local$name == "ID1",][1, 2] - 6.928203) < 1e-6) + expect_true(is.nan(df7_local[df7_local$name == "ID2",][1, 2])) + + mockLines3 <- c("{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}", + "{\"name\":\"Justin\", \"age\":1}") + jsonPath3 <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(mockLines3, jsonPath3) + df8 <- jsonFile(sqlContext, jsonPath3) + gd3 <- groupBy(df8, "name") + gd3_local <- collect(sum(gd3)) + expect_equal(60, gd3_local[gd3_local$name == "Andy",][1, 2]) + expect_equal(20, gd3_local[gd3_local$name == "Justin",][1, 2]) + + expect_true(abs(collect(agg(df, sd(df$age)))[1, 1] - 7.778175) < 1e-6) + gd3_local <- collect(agg(gd3, var(df8$age))) + expect_equal(162, gd3_local[gd3_local$name == "Justin",][1, 2]) + + # make sure base:: or stats::sd, var are working + expect_true(abs(sd(1:2) - 0.7071068) < 1e-6) + expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6) + + unlink(jsonPath2) + unlink(jsonPath3) }) test_that("arrange() and orderBy() on a DataFrame", { @@ -1238,7 +1305,7 @@ test_that("mutate(), transform(), rename() and names()", { expect_equal(columns(transformedDF)[4], "newAge2") expect_equal(first(filter(transformedDF, transformedDF$name == "Andy"))$newAge, -30) - # test if transform on local data frames works + # test if base::transform on local data frames works # ensure the proper signature is used - otherwise this will fail to run attach(airquality) result <- transform(Ozone, logOzone = log(Ozone)) @@ -1467,8 +1534,9 @@ test_that("SQL error message is returned from JVM", { expect_equal(grepl("Table not found: blah", retError), TRUE) }) +irisDF <- createDataFrame(sqlContext, iris) + test_that("Method as.data.frame as a synonym for collect()", { - irisDF <- createDataFrame(sqlContext, iris) expect_equal(as.data.frame(irisDF), collect(irisDF)) irisDF2 <- irisDF[irisDF$Species == "setosa", ] expect_equal(as.data.frame(irisDF2), collect(irisDF2)) @@ -1503,6 +1571,27 @@ test_that("with() on a DataFrame", { expect_equal(nrow(sum2), 35) }) +test_that("Method coltypes() to get R's data types of a DataFrame", { + expect_equal(coltypes(irisDF), c(rep("numeric", 4), "character")) + + data <- data.frame(c1=c(1,2,3), + c2=c(T,F,T), + c3=c("2015/01/01 10:00:00", "2015/01/02 10:00:00", "2015/01/03 10:00:00")) + + schema <- structType(structField("c1", "byte"), + structField("c3", "boolean"), + structField("c4", "timestamp")) + + # Test primitive types + DF <- createDataFrame(sqlContext, data, schema) + expect_equal(coltypes(DF), c("integer", "logical", "POSIXct")) + + # Test complex types + x <- createDataFrame(sqlContext, list(list(as.environment( + list("a"="b", "c"="d", "e"="f"))))) + expect_equal(coltypes(x), "map") +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 3584b418a71a9..f55beac6c8c07 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -18,10 +18,11 @@ # Worker daemon rLibDir <- Sys.getenv("SPARKR_RLIBDIR") -script <- paste(rLibDir, "SparkR/worker/worker.R", sep = "/") +dirs <- strsplit(rLibDir, ",")[[1]] +script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R") # preload SparkR package, speedup worker -.libPaths(c(rLibDir, .libPaths())) +.libPaths(c(dirs, .libPaths())) suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 0c3b0d1f4be20..3ae072beca11b 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -35,10 +35,11 @@ bootTime <- currentTimeSecs() bootElap <- elapsedSecs() rLibDir <- Sys.getenv("SPARKR_RLIBDIR") +dirs <- strsplit(rLibDir, ",")[[1]] # Set libPaths to include SparkR package as loadNamespace needs this # TODO: Figure out if we can avoid this by not loading any objects that require # SparkR namespace -.libPaths(c(rLibDir, .libPaths())) +.libPaths(c(dirs, .libPaths())) suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) diff --git a/core/pom.xml b/core/pom.xml index 5e9e758d72b76..37e3f168ab374 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -51,6 +51,10 @@ com.twitter chill-java + + org.apache.xbean + xbean-asm5-shaded + org.apache.hadoop hadoop-client diff --git a/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java new file mode 100644 index 0000000000000..279639af5d430 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A function that returns zero or more output records from each grouping key and its values from 2 + * Datasets. + */ +public interface CoGroupFunction extends Serializable { + Iterable call(K key, Iterator left, Iterator right) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java new file mode 100644 index 0000000000000..e8d999dd00135 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for a function used in Dataset's filter function. + * + * If the function returns true, the element is discarded in the returned Dataset. + */ +public interface FilterFunction extends Serializable { + boolean call(T value) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java index 23f5fdd43631b..ef0d1824121ec 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java @@ -23,5 +23,5 @@ * A function that returns zero or more output records from each input record. */ public interface FlatMapFunction extends Serializable { - public Iterable call(T t) throws Exception; + Iterable call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java index c48e92f535ff5..14a98a38ef5ab 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java @@ -23,5 +23,5 @@ * A function that takes two inputs and returns zero or more output records. */ public interface FlatMapFunction2 extends Serializable { - public Iterable call(T1 t1, T2 t2) throws Exception; + Iterable call(T1 t1, T2 t2) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java new file mode 100644 index 0000000000000..18a2d733ca70d --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A function that returns zero or more output records from each grouping key and its values. + */ +public interface FlatMapGroupFunction extends Serializable { + Iterable call(K key, Iterator values) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java new file mode 100644 index 0000000000000..07e54b28fa12c --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for a function used in Dataset's foreach function. + * + * Spark will invoke the call function on each element in the input Dataset. + */ +public interface ForeachFunction extends Serializable { + void call(T t) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java new file mode 100644 index 0000000000000..4938a51bcd712 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base interface for a function used in Dataset's foreachPartition function. + */ +public interface ForeachPartitionFunction extends Serializable { + void call(Iterator t) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/core/src/main/java/org/apache/spark/api/java/function/Function0.java index 38e410c5debe6..c86928dd05408 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function0.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function0.java @@ -23,5 +23,5 @@ * A zero-argument function that returns an R. */ public interface Function0 extends Serializable { - public R call() throws Exception; + R call() throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function4.java b/core/src/main/java/org/apache/spark/api/java/function/Function4.java new file mode 100644 index 0000000000000..fd727d64863d7 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/Function4.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A four-argument function that takes arguments of type T1, T2, T3 and T4 and returns an R. + */ +public interface Function4 extends Serializable { + public R call(T1 v1, T2 v2, T3 v3, T4 v4) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java new file mode 100644 index 0000000000000..3ae6ef44898e1 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for a map function used in Dataset's map function. + */ +public interface MapFunction extends Serializable { + U call(T value) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java new file mode 100644 index 0000000000000..2935f9986a560 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base interface for a map function used in GroupedDataset's map function. + */ +public interface MapGroupFunction extends Serializable { + R call(K key, Iterator values) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java new file mode 100644 index 0000000000000..6cb569ce0cb6b --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base interface for function used in Dataset's mapPartitions. + */ +public interface MapPartitionsFunction extends Serializable { + Iterable call(Iterator input) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java new file mode 100644 index 0000000000000..ee092d0058f44 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for function used in Dataset's reduce. + */ +public interface ReduceFunction extends Serializable { + T call(T v1, T v2) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index ee82d679935c0..a1a1fb01426a0 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -125,7 +125,7 @@ public void write(Iterator> records) throws IOException { assert (partitionWriters == null); if (!records.hasNext()) { partitionLengths = new long[numPartitions]; - shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); return; } @@ -155,9 +155,10 @@ public void write(Iterator> records) throws IOException { writer.commitAndClose(); } - partitionLengths = - writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId)); - shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); + File tmp = Utils.tempFileWith(output); + partitionLengths = writePartitionedFile(tmp); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 6a0a89e81c321..744c3008ca50e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -41,7 +41,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; -import org.apache.spark.io.LZFCompressionCodec; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; @@ -53,7 +53,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; -import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; @Private public class UnsafeShuffleWriter extends ShuffleWriter { @@ -206,8 +206,10 @@ void closeAndWriteOutput() throws IOException { final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; final long[] partitionLengths; + final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); + final File tmp = Utils.tempFileWith(output); try { - partitionLengths = mergeSpills(spills); + partitionLengths = mergeSpills(spills, tmp); } finally { for (SpillInfo spill : spills) { if (spill.file.exists() && ! spill.file.delete()) { @@ -215,7 +217,7 @@ void closeAndWriteOutput() throws IOException { } } } - shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @@ -248,8 +250,7 @@ void forceSorterToSpill() throws IOException { * * @return the partition lengths in the merged file. */ - private long[] mergeSpills(SpillInfo[] spills) throws IOException { - final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId); + private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException { final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true); final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); final boolean fastMergeEnabled = diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css index 3b4ae2ed354b8..9cc5c79f67346 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css @@ -122,3 +122,7 @@ stroke: #52C366; stroke-width: 2px; } + +.tooltip-inner { + white-space: pre-wrap; +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 7421821e2601b..4bbd0b038c00f 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -96,7 +96,23 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private def assertNotStopped(): Unit = { if (stopped.get()) { - throw new IllegalStateException("Cannot call methods on a stopped SparkContext") + val activeContext = SparkContext.activeContext.get() + val activeCreationSite = + if (activeContext == null) { + "(No active SparkContext.)" + } else { + activeContext.creationSite.longForm + } + throw new IllegalStateException( + s"""Cannot call methods on a stopped SparkContext. + |This stopped SparkContext was created at: + | + |${creationSite.longForm} + | + |The currently active SparkContext was created at: + | + |$activeCreationSite + """.stripMargin) } } @@ -863,10 +879,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli new WholeTextFileRDD( this, classOf[WholeTextFileInputFormat], - classOf[String], - classOf[String], + classOf[Text], + classOf[Text], updateConf, - minPartitions).setName(path) + minPartitions).setName(path).map(record => (record._1.toString, record._2.toString)) } /** @@ -1787,10 +1803,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * has overridden the call site using `setCallSite()`, this will return the user's version. */ private[spark] def getCallSite(): CallSite = { - Option(getLocalProperty(CallSite.SHORT_FORM)).map { case shortCallSite => - val longCallSite = Option(getLocalProperty(CallSite.LONG_FORM)).getOrElse("") - CallSite(shortCallSite, longCallSite) - }.getOrElse(Utils.getCallSite()) + val callSite = Utils.getCallSite() + CallSite( + Option(getLocalProperty(CallSite.SHORT_FORM)).getOrElse(callSite.shortForm), + Option(getLocalProperty(CallSite.LONG_FORM)).getOrElse(callSite.longForm) + ) } /** diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index b7e72d4d0ed0b..8b3be0da2c8c4 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -113,6 +113,7 @@ private[spark] object RBackend extends Logging { val dos = new DataOutputStream(new FileOutputStream(f)) dos.writeInt(boundPort) dos.writeInt(listenPort) + SerDe.writeString(dos, RUtils.rPackages.getOrElse("")) dos.close() f.renameTo(new File(path)) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 6b418e908cb53..7509b3d3f44bb 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -400,14 +400,14 @@ private[r] object RRDD { val rOptions = "--vanilla" val rLibDir = RUtils.sparkRPackagePath(isDriver = false) - val rExecScript = rLibDir + "/SparkR/worker/" + script + val rExecScript = rLibDir(0) + "/SparkR/worker/" + script val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript)) // Unset the R_TESTS environment variable for workers. // This is set by R CMD check as startup.Rs // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) // and confuses worker script which tries to load a non-existent file pb.environment().put("R_TESTS", "") - pb.environment().put("SPARKR_RLIBDIR", rLibDir) + pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) pb.environment().put("SPARKR_WORKER_PORT", port.toString) pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index fd5646b5b6372..16157414fd120 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -23,6 +23,10 @@ import java.util.Arrays import org.apache.spark.{SparkEnv, SparkException} private[spark] object RUtils { + // Local path where R binary packages built from R source code contained in the spark + // packages specified with "--packages" or "--jars" command line option reside. + var rPackages: Option[String] = None + /** * Get the SparkR package path in the local spark distribution. */ @@ -34,11 +38,15 @@ private[spark] object RUtils { } /** - * Get the SparkR package path in various deployment modes. + * Get the list of paths for R packages in various deployment modes, of which the first + * path is for the SparkR package itself. The second path is for R packages built as + * part of Spark Packages, if any exist. Spark Packages can be provided through the + * "--packages" or "--jars" command line options. + * * This assumes that Spark properties `spark.master` and `spark.submit.deployMode` * and environment variable `SPARK_HOME` are set. */ - def sparkRPackagePath(isDriver: Boolean): String = { + def sparkRPackagePath(isDriver: Boolean): Seq[String] = { val (master, deployMode) = if (isDriver) { (sys.props("spark.master"), sys.props("spark.submit.deployMode")) @@ -51,15 +59,30 @@ private[spark] object RUtils { val isYarnClient = master != null && master.contains("yarn") && deployMode == "client" // In YARN mode, the SparkR package is distributed as an archive symbolically - // linked to the "sparkr" file in the current directory. Note that this does not apply - // to the driver in client mode because it is run outside of the cluster. + // linked to the "sparkr" file in the current directory and additional R packages + // are distributed as an archive symbolically linked to the "rpkg" file in the + // current directory. + // + // Note that this does not apply to the driver in client mode because it is run + // outside of the cluster. if (isYarnCluster || (isYarnClient && !isDriver)) { - new File("sparkr").getAbsolutePath + val sparkRPkgPath = new File("sparkr").getAbsolutePath + val rPkgPath = new File("rpkg") + if (rPkgPath.exists()) { + Seq(sparkRPkgPath, rPkgPath.getAbsolutePath) + } else { + Seq(sparkRPkgPath) + } } else { // Otherwise, assume the package is local // TODO: support this for Mesos - localSparkRPackagePath.getOrElse { - throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.") + val sparkRPkgPath = localSparkRPackagePath.getOrElse { + throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.") + } + if (!rPackages.isEmpty) { + Seq(sparkRPkgPath, rPackages.get) + } else { + Seq(sparkRPkgPath) } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 6840a3ae831f0..a039d543c35e7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -47,7 +47,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) private val blockHandler = newShuffleBlockHandler(transportConf) - private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler) + private val transportContext: TransportContext = + new TransportContext(transportConf, blockHandler, true) private var server: TransportServer = _ diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala index 7d160b6790eaa..d46dc87a92c97 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala @@ -100,20 +100,29 @@ private[deploy] object RPackageUtils extends Logging { * Runs the standard R package installation code to build the R package from source. * Multiple runs don't cause problems. */ - private def rPackageBuilder(dir: File, printStream: PrintStream, verbose: Boolean): Boolean = { + private def rPackageBuilder( + dir: File, + printStream: PrintStream, + verbose: Boolean, + libDir: String): Boolean = { // this code should be always running on the driver. - val pathToSparkR = RUtils.localSparkRPackagePath.getOrElse( - throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.")) val pathToPkg = Seq(dir, "R", "pkg").mkString(File.separator) - val installCmd = baseInstallCmd ++ Seq(pathToSparkR, pathToPkg) + val installCmd = baseInstallCmd ++ Seq(libDir, pathToPkg) if (verbose) { print(s"Building R package with the command: $installCmd", printStream) } try { val builder = new ProcessBuilder(installCmd.asJava) builder.redirectErrorStream(true) + + // Put the SparkR package directory into R library search paths in case this R package + // may depend on SparkR. val env = builder.environment() - env.clear() + val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) + env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(",")) + env.put("R_PROFILE_USER", + Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator)) + val process = builder.start() new RedirectThread(process.getInputStream, printStream, "redirect R packaging").start() process.waitFor() == 0 @@ -170,8 +179,11 @@ private[deploy] object RPackageUtils extends Logging { if (checkManifestForR(jar)) { print(s"$file contains R source code. Now installing package.", printStream, Level.INFO) val rSource = extractRFolder(jar, printStream, verbose) + if (RUtils.rPackages.isEmpty) { + RUtils.rPackages = Some(Utils.createTempDir().getAbsolutePath) + } try { - if (!rPackageBuilder(rSource, printStream, verbose)) { + if (!rPackageBuilder(rSource, printStream, verbose, RUtils.rPackages.get)) { print(s"ERROR: Failed to build R package in $file.", printStream) print(RJarDoc, printStream) } @@ -208,7 +220,7 @@ private[deploy] object RPackageUtils extends Logging { } } - /** Zips all the libraries found with SparkR in the R/lib directory for distribution with Yarn. */ + /** Zips all the R libraries built for distribution to the cluster. */ private[deploy] def zipRLibraries(dir: File, name: String): File = { val filesToBundle = listFilesRecursively(dir, Seq(".zip")) // create a zip file from scratch, do not append to existing file. diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index ed183cf16a9cb..661f7317c674b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -82,9 +82,10 @@ object RRunner { val env = builder.environment() env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) - env.put("SPARKR_PACKAGE_DIR", rPackageDir) + // Put the R package directories into an env variable of comma-separated paths + env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(",")) env.put("R_PROFILE_USER", - Seq(rPackageDir, "SparkR", "profile", "general.R").mkString(File.separator)) + Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator)) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 84ae122f44370..2e912b59afdb8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -39,7 +39,7 @@ import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver} -import org.apache.spark.{SparkUserAppException, SPARK_VERSION} +import org.apache.spark.{SparkException, SparkUserAppException, SPARK_VERSION} import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.rest._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -83,6 +83,7 @@ object SparkSubmit { private val PYSPARK_SHELL = "pyspark-shell" private val SPARKR_SHELL = "sparkr-shell" private val SPARKR_PACKAGE_ARCHIVE = "sparkr.zip" + private val R_PACKAGE_ARCHIVE = "rpkg.zip" private val CLASS_NOT_FOUND_EXIT_STATUS = 101 @@ -362,22 +363,46 @@ object SparkSubmit { } } - // In YARN mode for an R app, add the SparkR package archive to archives - // that can be distributed with the job + // In YARN mode for an R app, add the SparkR package archive and the R package + // archive containing all of the built R libraries to archives so that they can + // be distributed with the job if (args.isR && clusterManager == YARN) { - val rPackagePath = RUtils.localSparkRPackagePath - if (rPackagePath.isEmpty) { + val sparkRPackagePath = RUtils.localSparkRPackagePath + if (sparkRPackagePath.isEmpty) { printErrorAndExit("SPARK_HOME does not exist for R application in YARN mode.") } - val rPackageFile = - RPackageUtils.zipRLibraries(new File(rPackagePath.get), SPARKR_PACKAGE_ARCHIVE) - if (!rPackageFile.exists()) { + val sparkRPackageFile = new File(sparkRPackagePath.get, SPARKR_PACKAGE_ARCHIVE) + if (!sparkRPackageFile.exists()) { printErrorAndExit(s"$SPARKR_PACKAGE_ARCHIVE does not exist for R application in YARN mode.") } - val localURI = Utils.resolveURI(rPackageFile.getAbsolutePath) + val sparkRPackageURI = Utils.resolveURI(sparkRPackageFile.getAbsolutePath).toString + // Distribute the SparkR package. // Assigns a symbol link name "sparkr" to the shipped package. - args.archives = mergeFileLists(args.archives, localURI.toString + "#sparkr") + args.archives = mergeFileLists(args.archives, sparkRPackageURI + "#sparkr") + + // Distribute the R package archive containing all the built R packages. + if (!RUtils.rPackages.isEmpty) { + val rPackageFile = + RPackageUtils.zipRLibraries(new File(RUtils.rPackages.get), R_PACKAGE_ARCHIVE) + if (!rPackageFile.exists()) { + printErrorAndExit("Failed to zip all the built R packages.") + } + + val rPackageURI = Utils.resolveURI(rPackageFile.getAbsolutePath).toString + // Assigns a symbol link name "rpkg" to the shipped package. + args.archives = mergeFileLists(args.archives, rPackageURI + "#rpkg") + } + } + + // TODO: Support distributing R packages with standalone cluster + if (args.isR && clusterManager == STANDALONE && !RUtils.rPackages.isEmpty) { + printErrorAndExit("Distributing R packages with standalone cluster is not supported.") + } + + // TODO: Support SparkR with mesos cluster + if (args.isR && clusterManager == MESOS) { + printErrorAndExit("SparkR is not supported for Mesos cluster.") } // If we're running a R app, set the main class to our specific R runner @@ -521,8 +546,19 @@ object SparkSubmit { sysProps.put("spark.yarn.isPython", "true") } if (args.principal != null) { - require(args.keytab != null, "Keytab must be specified when the keytab is specified") - UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + require(args.keytab != null, "Keytab must be specified when principal is specified") + if (!new File(args.keytab).exists()) { + throw new SparkException(s"Keytab file: ${args.keytab} does not exist") + } else { + // Add keytab and principal configurations in sysProps to make them available + // for later use; e.g. in spark sql, the isolated class loader used to talk + // to HiveMetastore will use these settings. They will be set as Java system + // properties and then loaded by SparkConf + sysProps.put("spark.yarn.keytab", args.keytab) + sysProps.put("spark.yarn.principal", args.principal) + + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 25ea6925434ab..afab362e213b5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.client import java.util.concurrent._ +import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} import scala.util.control.NonFatal @@ -49,9 +50,9 @@ private[spark] class AppClient( private val REGISTRATION_TIMEOUT_SECONDS = 20 private val REGISTRATION_RETRIES = 3 - private var endpoint: RpcEndpointRef = null - private var appId: String = null - @volatile private var registered = false + private val endpoint = new AtomicReference[RpcEndpointRef] + private val appId = new AtomicReference[String] + private val registered = new AtomicBoolean(false) private class ClientEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { @@ -59,16 +60,17 @@ private[spark] class AppClient( private var master: Option[RpcEndpointRef] = None // To avoid calling listener.disconnected() multiple times private var alreadyDisconnected = false - @volatile private var alreadyDead = false // To avoid calling listener.dead() multiple times - @volatile private var registerMasterFutures: Array[JFuture[_]] = null - @volatile private var registrationRetryTimer: JScheduledFuture[_] = null + // To avoid calling listener.dead() multiple times + private val alreadyDead = new AtomicBoolean(false) + private val registerMasterFutures = new AtomicReference[Array[JFuture[_]]] + private val registrationRetryTimer = new AtomicReference[JScheduledFuture[_]] // A thread pool for registering with masters. Because registering with a master is a blocking // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same // time so that we can register with all masters. private val registerMasterThreadPool = new ThreadPoolExecutor( 0, - masterRpcAddresses.size, // Make sure we can register with all masters at the same time + masterRpcAddresses.length, // Make sure we can register with all masters at the same time 60L, TimeUnit.SECONDS, new SynchronousQueue[Runnable](), ThreadUtils.namedThreadFactory("appclient-register-master-threadpool")) @@ -77,6 +79,11 @@ private[spark] class AppClient( private val registrationRetryThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("appclient-registration-retry-thread") + // A thread pool to perform receive then reply actions in a thread so as not to block the + // event loop. + private val askAndReplyThreadPool = + ThreadUtils.newDaemonCachedThreadPool("appclient-receive-and-reply-threadpool") + override def onStart(): Unit = { try { registerWithMaster(1) @@ -95,7 +102,7 @@ private[spark] class AppClient( for (masterAddress <- masterRpcAddresses) yield { registerMasterThreadPool.submit(new Runnable { override def run(): Unit = try { - if (registered) { + if (registered.get) { return } logInfo("Connecting to master " + masterAddress.toSparkURL + "...") @@ -118,22 +125,22 @@ private[spark] class AppClient( * nthRetry means this is the nth attempt to register with master. */ private def registerWithMaster(nthRetry: Int) { - registerMasterFutures = tryRegisterAllMasters() - registrationRetryTimer = registrationRetryThread.scheduleAtFixedRate(new Runnable { + registerMasterFutures.set(tryRegisterAllMasters()) + registrationRetryTimer.set(registrationRetryThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = { Utils.tryOrExit { - if (registered) { - registerMasterFutures.foreach(_.cancel(true)) + if (registered.get) { + registerMasterFutures.get.foreach(_.cancel(true)) registerMasterThreadPool.shutdownNow() } else if (nthRetry >= REGISTRATION_RETRIES) { markDead("All masters are unresponsive! Giving up.") } else { - registerMasterFutures.foreach(_.cancel(true)) + registerMasterFutures.get.foreach(_.cancel(true)) registerWithMaster(nthRetry + 1) } } } - }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS) + }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)) } /** @@ -158,10 +165,10 @@ private[spark] class AppClient( // RegisteredApplications due to an unstable network. // 2. Receive multiple RegisteredApplication from different masters because the master is // changing. - appId = appId_ - registered = true + appId.set(appId_) + registered.set(true) master = Some(masterRef) - listener.connected(appId) + listener.connected(appId.get) case ApplicationRemoved(message) => markDead("Master removed our application: %s".format(message)) @@ -173,7 +180,7 @@ private[spark] class AppClient( cores)) // FIXME if changing master and `ExecutorAdded` happen at the same time (the order is not // guaranteed), `ExecutorStateChanged` may be sent to a dead master. - sendToMaster(ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None)) + sendToMaster(ExecutorStateChanged(appId.get, id, ExecutorState.RUNNING, None, None)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => @@ -188,19 +195,19 @@ private[spark] class AppClient( logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) master = Some(masterRef) alreadyDisconnected = false - masterRef.send(MasterChangeAcknowledged(appId)) + masterRef.send(MasterChangeAcknowledged(appId.get)) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case StopAppClient => markDead("Application has been stopped.") - sendToMaster(UnregisterApplication(appId)) + sendToMaster(UnregisterApplication(appId.get)) context.reply(true) stop() case r: RequestExecutors => master match { - case Some(m) => context.reply(m.askWithRetry[Boolean](r)) + case Some(m) => askAndReplyAsync(m, context, r) case None => logWarning("Attempted to request executors before registering with Master.") context.reply(false) @@ -208,13 +215,32 @@ private[spark] class AppClient( case k: KillExecutors => master match { - case Some(m) => context.reply(m.askWithRetry[Boolean](k)) + case Some(m) => askAndReplyAsync(m, context, k) case None => logWarning("Attempted to kill executors before registering with Master.") context.reply(false) } } + private def askAndReplyAsync[T]( + endpointRef: RpcEndpointRef, + context: RpcCallContext, + msg: T): Unit = { + // Create a thread to ask a message and reply with the result. Allow thread to be + // interrupted during shutdown, otherwise context must be notified of NonFatal errors. + askAndReplyThreadPool.execute(new Runnable { + override def run(): Unit = { + try { + context.reply(endpointRef.askWithRetry[Boolean](msg)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(t) => + context.sendFailure(t) + } + } + }) + } + override def onDisconnected(address: RpcAddress): Unit = { if (master.exists(_.address == address)) { logWarning(s"Connection to $address failed; waiting for master to reconnect...") @@ -239,38 +265,39 @@ private[spark] class AppClient( } def markDead(reason: String) { - if (!alreadyDead) { + if (!alreadyDead.get) { listener.dead(reason) - alreadyDead = true + alreadyDead.set(true) } } override def onStop(): Unit = { - if (registrationRetryTimer != null) { - registrationRetryTimer.cancel(true) + if (registrationRetryTimer.get != null) { + registrationRetryTimer.get.cancel(true) } registrationRetryThread.shutdownNow() - registerMasterFutures.foreach(_.cancel(true)) + registerMasterFutures.get.foreach(_.cancel(true)) registerMasterThreadPool.shutdownNow() + askAndReplyThreadPool.shutdownNow() } } def start() { // Just launch an rpcEndpoint; it will call back into the listener. - endpoint = rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv)) + endpoint.set(rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv))) } def stop() { - if (endpoint != null) { + if (endpoint.get != null) { try { val timeout = RpcUtils.askRpcTimeout(conf) - timeout.awaitResult(endpoint.ask[Boolean](StopAppClient)) + timeout.awaitResult(endpoint.get.ask[Boolean](StopAppClient)) } catch { case e: TimeoutException => logInfo("Stop request to Master timed out; it may already be shut down.") } - endpoint = null + endpoint.set(null) } } @@ -281,8 +308,8 @@ private[spark] class AppClient( * @return whether the request is acknowledged. */ def requestTotalExecutors(requestedTotal: Int): Boolean = { - if (endpoint != null && appId != null) { - endpoint.askWithRetry[Boolean](RequestExecutors(appId, requestedTotal)) + if (endpoint.get != null && appId.get != null) { + endpoint.get.askWithRetry[Boolean](RequestExecutors(appId.get, requestedTotal)) } else { logWarning("Attempted to request executors before driver fully initialized.") false @@ -294,8 +321,8 @@ private[spark] class AppClient( * @return whether the kill request is acknowledged. */ def killExecutors(executorIds: Seq[String]): Boolean = { - if (endpoint != null && appId != null) { - endpoint.askWithRetry[Boolean](KillExecutors(appId, executorIds)) + if (endpoint.get != null && appId.get != null) { + endpoint.get.askWithRetry[Boolean](KillExecutors(appId.get, executorIds)) } else { logWarning("Attempted to kill executors before driver fully initialized.") false diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 6174fc11f83d8..e41554a5a6d26 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -28,14 +28,17 @@ import org.apache.spark.ui.JettyUtils._ * Web UI server for the standalone master. */ private[master] -class MasterWebUI(val master: Master, requestedPort: Int) +class MasterWebUI( + val master: Master, + requestedPort: Int, + customMasterPage: Option[MasterPage] = None) extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging with UIRoot { val masterEndpointRef = master.self val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true) - val masterPage = new MasterPage(this) + val masterPage = customMasterPage.getOrElse(new MasterPage(this)) initialize() diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index e8ef60bd5428a..bc67fd460d9a9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -46,7 +46,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") val schedulerHeaders = Seq("Scheduler property", "Value") val commandEnvHeaders = Seq("Command environment variable", "Value") val launchedHeaders = Seq("Launched property", "Value") - val commandHeaders = Seq("Comamnd property", "Value") + val commandHeaders = Seq("Command property", "Value") val retryHeaders = Seq("Last failed status", "Next retry time", "Retry count") val driverDescription = Iterable.apply(driverState.description) val submissionState = Iterable.apply(driverState.submissionState) diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index 1ba34a11414a2..413408723b54d 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -20,6 +20,7 @@ package org.apache.spark.input import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.InputSplit import org.apache.hadoop.mapreduce.JobContext import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat @@ -33,14 +34,13 @@ import org.apache.hadoop.mapreduce.TaskAttemptContext */ private[spark] class WholeTextFileInputFormat - extends CombineFileInputFormat[String, String] with Configurable { + extends CombineFileInputFormat[Text, Text] with Configurable { override protected def isSplitable(context: JobContext, file: Path): Boolean = false override def createRecordReader( split: InputSplit, - context: TaskAttemptContext): RecordReader[String, String] = { - + context: TaskAttemptContext): RecordReader[Text, Text] = { val reader = new ConfigurableCombineFileRecordReader(split, context, classOf[WholeTextFileRecordReader]) reader.setConf(getConf) diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala index 31bde8a78f3c6..b56b2aa88a414 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala @@ -49,7 +49,7 @@ private[spark] class WholeTextFileRecordReader( split: CombineFileSplit, context: TaskAttemptContext, index: Integer) - extends RecordReader[String, String] with Configurable { + extends RecordReader[Text, Text] with Configurable { private[this] val path = split.getPath(index) private[this] val fs = path.getFileSystem( @@ -58,8 +58,8 @@ private[spark] class WholeTextFileRecordReader( // True means the current file has been processed, then skip it. private[this] var processed = false - private[this] val key = path.toString - private[this] var value: String = null + private[this] val key: Text = new Text(path.toString) + private[this] var value: Text = null override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = {} @@ -67,9 +67,9 @@ private[spark] class WholeTextFileRecordReader( override def getProgress: Float = if (processed) 1.0f else 0.0f - override def getCurrentKey: String = key + override def getCurrentKey: Text = key - override def getCurrentValue: String = value + override def getCurrentValue: Text = value override def nextKeyValue(): Boolean = { if (!processed) { @@ -83,7 +83,7 @@ private[spark] class WholeTextFileRecordReader( ByteStreams.toByteArray(fileIn) } - value = new Text(innerBuffer).toString + value = new Text(innerBuffer) Closeables.close(fileIn, false) processed = true true diff --git a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala index 3ea984c501e02..a5d41a1eeb479 100644 --- a/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala +++ b/core/src/main/scala/org/apache/spark/launcher/LauncherBackend.scala @@ -21,7 +21,7 @@ import java.net.{InetAddress, Socket} import org.apache.spark.SPARK_VERSION import org.apache.spark.launcher.LauncherProtocol._ -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{ThreadUtils, Utils} /** * A class that can be used to talk to a launcher server. Users should extend this class to @@ -88,12 +88,20 @@ private[spark] abstract class LauncherBackend { */ protected def onDisconnected() : Unit = { } + private def fireStopRequest(): Unit = { + val thread = LauncherBackend.threadFactory.newThread(new Runnable() { + override def run(): Unit = Utils.tryLogNonFatalError { + onStopRequest() + } + }) + thread.start() + } private class BackendConnection(s: Socket) extends LauncherConnection(s) { override protected def handle(m: Message): Unit = m match { case _: Stop => - onStopRequest() + fireStopRequest() case _ => throw new IllegalArgumentException(s"Unexpected message type: ${m.getClass().getName()}") diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 0453614f6a1d3..7db583468792e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -213,6 +213,12 @@ class HadoopRDD[K, V]( val inputMetrics = context.taskMetrics.getInputMetricsForReadMethod(DataReadMethod.Hadoop) + // Sets the thread local variable for the file's name + split.inputSplit.value match { + case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDD.unsetInputFileName() + } + // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { @@ -250,6 +256,7 @@ class HadoopRDD[K, V]( override def close() { if (reader != null) { + SqlNewHadoopRDD.unsetInputFileName() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 9c4b70844bdbe..d1960990da0fe 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -28,12 +28,11 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.input.WholeTextFileInputFormat import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, Utils} +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.StorageLevel @@ -59,7 +58,6 @@ private[spark] class NewHadoopPartition( * @param inputFormatClass Storage format of the data to be read. * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. - * @param conf The Hadoop configuration. */ @DeveloperApi class NewHadoopRDD[K, V]( @@ -282,32 +280,3 @@ private[spark] object NewHadoopRDD { } } } - -private[spark] class WholeTextFileRDD( - sc : SparkContext, - inputFormatClass: Class[_ <: WholeTextFileInputFormat], - keyClass: Class[String], - valueClass: Class[String], - conf: Configuration, - minPartitions: Int) - extends NewHadoopRDD[String, String](sc, inputFormatClass, keyClass, valueClass, conf) { - - override def getPartitions: Array[Partition] = { - val inputFormat = inputFormatClass.newInstance - val conf = getConf - inputFormat match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - val jobContext = newJobContext(conf, jobId) - inputFormat.setMinPartitions(jobContext, minPartitions) - val rawSplits = inputFormat.getSplits(jobContext).toArray - val result = new Array[Partition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) - } - result - } -} - diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 800ef53cbef07..2aeb5eeaad32c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -705,6 +705,24 @@ abstract class RDD[T: ClassTag]( preservesPartitioning) } + /** + * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a + * performance API to be used carefully only if we are sure that the RDD elements are + * serializable and don't require closure cleaning. + * + * @param preservesPartitioning indicates whether the input function preserves the partitioner, + * which should be `false` unless this is a pair RDD and the input function doesn't modify + * the keys. + */ + private[spark] def mapPartitionsInternal[U: ClassTag]( + f: Iterator[T] => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { + new MapPartitionsRDD( + this, + (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter), + preservesPartitioning) + } + /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. diff --git a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala new file mode 100644 index 0000000000000..e3f14fe7ef0f8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.io.{Text, Writable} +import org.apache.hadoop.mapreduce.InputSplit + +import org.apache.spark.{Partition, SparkContext} +import org.apache.spark.input.WholeTextFileInputFormat + +/** + * An RDD that reads a bunch of text files in, and each text file becomes one record. + */ +private[spark] class WholeTextFileRDD( + sc : SparkContext, + inputFormatClass: Class[_ <: WholeTextFileInputFormat], + keyClass: Class[Text], + valueClass: Class[Text], + conf: Configuration, + minPartitions: Int) + extends NewHadoopRDD[Text, Text](sc, inputFormatClass, keyClass, valueClass, conf) { + + override def getPartitions: Array[Partition] = { + val inputFormat = inputFormatClass.newInstance + val conf = getConf + inputFormat match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val jobContext = newJobContext(conf, jobId) + inputFormat.setMinPartitions(jobContext, minPartitions) + val rawSplits = inputFormat.getSplits(jobContext).toArray + val result = new Array[Partition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index c72b588db57fe..464027f07cc88 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -21,8 +21,6 @@ import javax.annotation.concurrent.GuardedBy import scala.util.control.NonFatal -import com.google.common.annotations.VisibleForTesting - import org.apache.spark.{Logging, SparkException} import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint} @@ -193,8 +191,10 @@ private[netty] class Inbox( def isEmpty: Boolean = inbox.synchronized { messages.isEmpty } - /** Called when we are dropping a message. Test cases override this to test message dropping. */ - @VisibleForTesting + /** + * Called when we are dropping a message. Test cases override this to test message dropping. + * Exposed for testing. + */ protected def onDrop(message: InboxMessage): Unit = { logWarning(s"Drop $message because $endpointRef is stopped") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 42c6788773b7a..b2e9a97129f08 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -19,9 +19,10 @@ package org.apache.spark.scheduler import java.io.{Externalizable, ObjectInput, ObjectOutput} +import org.roaringbitmap.RoaringBitmap + import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils -import org.roaringbitmap.RoaringBitmap /** * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the @@ -172,6 +173,9 @@ private[spark] object HighlyCompressedMapStatus { var i = 0 var numNonEmptyBlocks: Int = 0 var totalSize: Long = 0 + // From a compression standpoint, it shouldn't matter whether we track empty or non-empty + // blocks. From a performance standpoint, we benefit from tracking empty blocks because + // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions. val emptyBlocks = new RoaringBitmap() val totalNumBlocks = uncompressedSizes.length while (i < totalNumBlocks) { @@ -189,8 +193,8 @@ private[spark] object HighlyCompressedMapStatus { } else { 0 } - emptyBlocks.runOptimize() emptyBlocks.trim() + emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 43d7d80b7aae1..5f136690f456c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -473,6 +473,7 @@ private[spark] class TaskSchedulerImpl( // If the host mapping still exists, it means we don't know the loss reason for the // executor. So call removeExecutor() to update tasks running on that executor when // the real loss reason is finally known. + logError(s"Actual reason for lost executor $executorId: ${reason.message}") removeExecutor(executorId, reason) case None => diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index f71d98feac050..3373caf0d15eb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -269,7 +269,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * Stop making resource offers for the given executor. The executor is marked as lost with * the loss reason still pending. * - * @return Whether executor was alive. + * @return Whether executor should be disabled */ protected def disableExecutor(executorId: String): Boolean = { val shouldDisable = CoarseGrainedSchedulerBackend.this.synchronized { @@ -277,7 +277,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp executorsPendingLossReason += executorId true } else { - false + // Returns true for explicitly killed executors, we also need to get pending loss reasons; + // For others return false. + executorsPendingToRemove.contains(executorId) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 05d9bc92f228b..5105475c760e2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -191,17 +191,19 @@ private[spark] class SparkDeploySchedulerBackend( } private def stop(finalState: SparkAppHandle.State): Unit = synchronized { - stopping = true + try { + stopping = true - launcherBackend.setState(finalState) - launcherBackend.close() + super.stop() + client.stop() - super.stop() - client.stop() - - val callback = shutdownCallback - if (callback != null) { - callback(this) + val callback = shutdownCallback + if (callback != null) { + callback(this) + } + } finally { + launcherBackend.setState(finalState) + launcherBackend.close() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index d10a77f8e5c78..2de9b6a651692 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -101,6 +101,10 @@ private[spark] class CoarseMesosSchedulerBackend( private val slaveOfferConstraints = parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + // reject offers with mismatched constraints in seconds + private val rejectOfferDurationForUnmetConstraints = + getRejectOfferDurationForUnmetConstraints(sc) + // A client for talking to the external shuffle service, if it is a private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { if (shuffleServiceEnabled) { @@ -249,48 +253,56 @@ private[spark] class CoarseMesosSchedulerBackend( val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt val id = offer.getId.getValue - if (taskIdToSlaveId.size < executorLimit && - totalCoresAcquired < maxCores && - meetsConstraints && - mem >= calculateTotalMemory(sc) && - cpus >= 1 && - failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && - !slaveIdsWithExecutors.contains(slaveId)) { - // Launch an executor on the slave - val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) - totalCoresAcquired += cpusToUse - val taskId = newMesosTaskId() - taskIdToSlaveId.put(taskId, slaveId) - slaveIdsWithExecutors += slaveId - coresByTaskId(taskId) = cpusToUse - // Gather cpu resources from the available resources and use them in the task. - val (remainingResources, cpuResourcesToUse) = - partitionResources(offer.getResourcesList, "cpus", cpusToUse) - val (_, memResourcesToUse) = - partitionResources(remainingResources.asJava, "mem", calculateTotalMemory(sc)) - val taskBuilder = MesosTaskInfo.newBuilder() - .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) - .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) - .setName("Task " + taskId) - .addAllResources(cpuResourcesToUse.asJava) - .addAllResources(memResourcesToUse.asJava) - - sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => - MesosSchedulerBackendUtil - .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder()) + if (meetsConstraints) { + if (taskIdToSlaveId.size < executorLimit && + totalCoresAcquired < maxCores && + mem >= calculateTotalMemory(sc) && + cpus >= 1 && + failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && + !slaveIdsWithExecutors.contains(slaveId)) { + // Launch an executor on the slave + val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) + totalCoresAcquired += cpusToUse + val taskId = newMesosTaskId() + taskIdToSlaveId.put(taskId, slaveId) + slaveIdsWithExecutors += slaveId + coresByTaskId(taskId) = cpusToUse + // Gather cpu resources from the available resources and use them in the task. + val (remainingResources, cpuResourcesToUse) = + partitionResources(offer.getResourcesList, "cpus", cpusToUse) + val (_, memResourcesToUse) = + partitionResources(remainingResources.asJava, "mem", calculateTotalMemory(sc)) + val taskBuilder = MesosTaskInfo.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) + .setSlaveId(offer.getSlaveId) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) + .setName("Task " + taskId) + .addAllResources(cpuResourcesToUse.asJava) + .addAllResources(memResourcesToUse.asJava) + + sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => + MesosSchedulerBackendUtil + .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder()) + } + + // Accept the offer and launch the task + logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname + d.launchTasks( + Collections.singleton(offer.getId), + Collections.singleton(taskBuilder.build()), filters) + } else { + // Decline the offer + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + d.declineOffer(offer.getId) } - - // accept the offer and launch the task - logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname - d.launchTasks( - Collections.singleton(offer.getId), - Collections.singleton(taskBuilder.build()), filters) } else { - // Decline the offer - logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - d.declineOffer(offer.getId) + // This offer does not meet constraints. We don't need to see it again. + // Decline the offer for a long period of time. + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus" + + s" for $rejectOfferDurationForUnmetConstraints seconds") + d.declineOffer(offer.getId, Filters.newBuilder() + .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build()) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index aaffac604a885..281965a5981bb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -63,6 +63,10 @@ private[spark] class MesosSchedulerBackend( private[this] val slaveOfferConstraints = parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + // reject offers with mismatched constraints in seconds + private val rejectOfferDurationForUnmetConstraints = + getRejectOfferDurationForUnmetConstraints(sc) + @volatile var appId: String = _ override def start() { @@ -212,29 +216,47 @@ private[spark] class MesosSchedulerBackend( */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { inClassLoader() { - // Fail-fast on offers we know will be rejected - val (usableOffers, unUsableOffers) = offers.asScala.partition { o => + // Fail first on offers with unmet constraints + val (offersMatchingConstraints, offersNotMatchingConstraints) = + offers.asScala.partition { o => + val offerAttributes = toAttributeMap(o.getAttributesList) + val meetsConstraints = + matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + + // add some debug messaging + if (!meetsConstraints) { + val id = o.getId.getValue + logDebug(s"Declining offer: $id with attributes: $offerAttributes") + } + + meetsConstraints + } + + // These offers do not meet constraints. We don't need to see them again. + // Decline the offer for a long period of time. + offersNotMatchingConstraints.foreach { o => + d.declineOffer(o.getId, Filters.newBuilder() + .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build()) + } + + // Of the matching constraints, see which ones give us enough memory and cores + val (usableOffers, unUsableOffers) = offersMatchingConstraints.partition { o => val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue val offerAttributes = toAttributeMap(o.getAttributesList) - // check if all constraints are satisfield - // 1. Attribute constraints - // 2. Memory requirements - // 3. CPU requirements - need at least 1 for executor, 1 for task - val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + // check offers for + // 1. Memory requirements + // 2. CPU requirements - need at least 1 for executor, 1 for task val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) - val meetsRequirements = - (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) || + (meetsMemoryRequirements && meetsCPURequirements) || (slaveIdToExecutorInfo.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) - - // add some debug messaging val debugstr = if (meetsRequirements) "Accepting" else "Declining" - val id = o.getId.getValue - logDebug(s"$debugstr offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + logDebug(s"$debugstr offer: ${o.getId.getValue} with attributes: " + + s"$offerAttributes mem: $mem cpu: $cpus") meetsRequirements } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 860c8e097b3b9..721861fbbc517 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -336,4 +336,8 @@ private[mesos] trait MesosSchedulerUtils extends Logging { } } + protected def getRejectOfferDurationForUnmetConstraints(sc: SparkContext): Long = { + sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForUnmetConstraints", "120s") + } + } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 183639c5407cf..d4152305b49eb 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -21,24 +21,25 @@ import java.io.{EOFException, IOException, InputStream, OutputStream} import java.nio.ByteBuffer import javax.annotation.Nullable +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.esotericsoftware.kryo.{Kryo, KryoException} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.avro.generic.{GenericData, GenericRecord} +import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, RoaringBitmap} + import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.HttpBroadcast import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.collection.{BitSet, CompactBuffer} +import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf, Utils} -import org.roaringbitmap.RoaringBitmap - -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. @@ -363,7 +364,11 @@ private[serializer] object KryoSerializer { classOf[CompressedMapStatus], classOf[HighlyCompressedMapStatus], classOf[RoaringBitmap], - classOf[BitSet], + classOf[RoaringArray], + classOf[RoaringArray.Element], + classOf[Array[RoaringArray.Element]], + classOf[ArrayContainer], + classOf[BitmapContainer], classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Byte]], diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index cd253a78c2b19..39fadd8783518 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -21,13 +21,13 @@ import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ -import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.Serializer import org.apache.spark.storage._ -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} +import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} +import org.apache.spark.{Logging, SparkConf, SparkEnv} /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { @@ -84,17 +84,8 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) val blockFile = blockManager.diskBlockManager.getFile(blockId) - // Because of previous failures, the shuffle file may already exist on this machine. - // If so, remove it. - if (blockFile.exists) { - if (blockFile.delete()) { - logInfo(s"Removed existing shuffle file $blockFile") - } else { - logWarning(s"Failed to remove existing shuffle file $blockFile") - } - } - blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize, - writeMetrics) + val tmp = Utils.tempFileWith(blockFile) + blockManager.getDiskWriter(blockId, tmp, serializerInstance, bufferSize, writeMetrics) } } // Creating the file to write to and creating a disk writer both involve interacting with diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 5e4c2b5d0a5c4..05b1eed7f3bef 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -21,13 +21,12 @@ import java.io._ import com.google.common.io.ByteStreams -import org.apache.spark.{SparkConf, SparkEnv, Logging} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID import org.apache.spark.storage._ import org.apache.spark.util.Utils - -import IndexShuffleBlockResolver.NOOP_REDUCE_ID +import org.apache.spark.{SparkEnv, Logging, SparkConf} /** * Create and maintain the shuffle blocks' mapping between logic block and physical file location. @@ -40,10 +39,13 @@ import IndexShuffleBlockResolver.NOOP_REDUCE_ID */ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData(). -private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver +private[spark] class IndexShuffleBlockResolver( + conf: SparkConf, + _blockManager: BlockManager = null) + extends ShuffleBlockResolver with Logging { - private lazy val blockManager = SparkEnv.get.blockManager + private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager) private val transportConf = SparkTransportConf.fromSparkConf(conf) @@ -74,14 +76,69 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB } } + /** + * Check whether the given index and data files match each other. + * If so, return the partition lengths in the data file. Otherwise return null. + */ + private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = { + // the index file should have `block + 1` longs as offset. + if (index.length() != (blocks + 1) * 8) { + return null + } + val lengths = new Array[Long](blocks) + // Read the lengths of blocks + val in = try { + new DataInputStream(new BufferedInputStream(new FileInputStream(index))) + } catch { + case e: IOException => + return null + } + try { + // Convert the offsets into lengths of each block + var offset = in.readLong() + if (offset != 0L) { + return null + } + var i = 0 + while (i < blocks) { + val off = in.readLong() + lengths(i) = off - offset + offset = off + i += 1 + } + } catch { + case e: IOException => + return null + } finally { + in.close() + } + + // the size of data file should match with index file + if (data.length() == lengths.sum) { + lengths + } else { + null + } + } + /** * Write an index file with the offsets of each block, plus a final offset at the end for the * end of the output file. This will be used by getBlockData to figure out where each block * begins and ends. + * + * It will commit the data and index file as an atomic operation, use the existing ones, or + * replace them with new ones. + * + * Note: the `lengths` will be updated to match the existing index file if use the existing ones. * */ - def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = { + def writeIndexFileAndCommit( + shuffleId: Int, + mapId: Int, + lengths: Array[Long], + dataTmp: File): Unit = { val indexFile = getIndexFile(shuffleId, mapId) - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) + val indexTmp = Utils.tempFileWith(indexFile) + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) Utils.tryWithSafeFinally { // We take in lengths of each block, need to convert it to offsets. var offset = 0L @@ -93,6 +150,37 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB } { out.close() } + + val dataFile = getDataFile(shuffleId, mapId) + // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure + // the following check and rename are atomic. + synchronized { + val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length) + if (existingLengths != null) { + // Another attempt for the same task has already written our map outputs successfully, + // so just use the existing partition lengths and delete our temporary map outputs. + System.arraycopy(existingLengths, 0, lengths, 0, lengths.length) + if (dataTmp != null && dataTmp.exists()) { + dataTmp.delete() + } + indexTmp.delete() + } else { + // This is the first successful attempt in writing the map outputs for this task, + // so override any existing index and data files with the ones we wrote. + if (indexFile.exists()) { + indexFile.delete() + } + if (dataFile.exists()) { + dataFile.delete() + } + if (!indexTmp.renameTo(indexFile)) { + throw new IOException("fail to rename file " + indexTmp + " to " + indexFile) + } + if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { + throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) + } + } + } } override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 41df70c602c30..412bf70000da7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.hash +import java.io.IOException + import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus @@ -106,6 +108,29 @@ private[spark] class HashShuffleWriter[K, V]( writer.commitAndClose() writer.fileSegment().length } + // rename all shuffle files to final paths + // Note: there is only one ShuffleBlockResolver in executor + shuffleBlockResolver.synchronized { + shuffle.writers.zipWithIndex.foreach { case (writer, i) => + val output = blockManager.diskBlockManager.getFile(writer.blockId) + if (sizes(i) > 0) { + if (output.exists()) { + // Use length of existing file and delete our own temporary one + sizes(i) = output.length() + writer.file.delete() + } else { + // Commit by renaming our temporary file to something the fetcher expects + if (!writer.file.renameTo(output)) { + throw new IOException(s"fail to rename ${writer.file} to $output") + } + } + } else { + if (output.exists()) { + output.delete() + } + } + } + } MapStatus(blockManager.shuffleServerId, sizes) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 808317b017a0f..f83cf8859e581 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -20,8 +20,9 @@ package org.apache.spark.shuffle.sort import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus -import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle} +import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter} import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.util.Utils import org.apache.spark.util.collection.ExternalSorter private[spark] class SortShuffleWriter[K, V, C]( @@ -65,11 +66,11 @@ private[spark] class SortShuffleWriter[K, V, C]( // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). - val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) + val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) + val tmp = Utils.tempFileWith(output) val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - val partitionLengths = sorter.writePartitionedFile(blockId, outputFile) - shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths) - + val partitionLengths = sorter.writePartitionedFile(blockId, tmp) + shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index 17b521f3e1d41..0fc0fb59d861f 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -62,6 +62,10 @@ private[spark] object ApplicationsListResource { new ApplicationInfo( id = app.id, name = app.name, + coresGranted = None, + maxCores = None, + coresPerExecutor = None, + memoryPerExecutorMB = None, attempts = app.attempts.map { internalAttemptInfo => new ApplicationAttemptInfo( attemptId = internalAttemptInfo.attemptId, @@ -81,6 +85,10 @@ private[spark] object ApplicationsListResource { new ApplicationInfo( id = internal.id, name = internal.desc.name, + coresGranted = Some(internal.coresGranted), + maxCores = internal.desc.maxCores, + coresPerExecutor = internal.desc.coresPerExecutor, + memoryPerExecutorMB = Some(internal.desc.memoryPerExecutorMB), attempts = Seq(new ApplicationAttemptInfo( attemptId = None, startTime = new Date(internal.startTime), diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 2bec64f2ef02b..baddfc50c1a40 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -25,6 +25,10 @@ import org.apache.spark.JobExecutionStatus class ApplicationInfo private[spark]( val id: String, val name: String, + val coresGranted: Option[Int], + val maxCores: Option[Int], + val coresPerExecutor: Option[Int], + val memoryPerExecutorMB: Option[Int], val attempts: Seq[ApplicationAttemptInfo]) class ApplicationAttemptInfo private[spark]( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index c374b93766225..661c706af32b1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -21,10 +21,10 @@ import java.io._ import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.concurrent.{ExecutionContext, Await, Future} import scala.concurrent.duration._ -import scala.util.control.NonFatal +import scala.concurrent.{Await, ExecutionContext, Future} import scala.util.Random +import scala.util.control.NonFatal import sun.nio.ch.DirectBuffer @@ -38,9 +38,8 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv -import org.apache.spark.serializer.{SerializerInstance, Serializer} +import org.apache.spark.serializer.{Serializer, SerializerInstance} import org.apache.spark.shuffle.ShuffleManager -import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.util._ private[spark] sealed trait BlockValues @@ -660,7 +659,7 @@ private[spark] class BlockManager( val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream, - syncWrites, writeMetrics) + syncWrites, writeMetrics, blockId) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index 80d426fadc65e..e2dd80f243930 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -34,14 +34,15 @@ import org.apache.spark.util.Utils * reopened again. */ private[spark] class DiskBlockObjectWriter( - file: File, + val file: File, serializerInstance: SerializerInstance, bufferSize: Int, compressStream: OutputStream => OutputStream, syncWrites: Boolean, // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. - writeMetrics: ShuffleWriteMetrics) + writeMetrics: ShuffleWriteMetrics, + val blockId: BlockId = null) extends OutputStream with Logging { diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 96062626b5045..87c1b981e7e13 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDDOperationScope, RDD} -import org.apache.spark.util.Utils +import org.apache.spark.util.{CallSite, Utils} @DeveloperApi class RDDInfo( @@ -28,6 +28,7 @@ class RDDInfo( val numPartitions: Int, var storageLevel: StorageLevel, val parentIds: Seq[Int], + val callSite: CallSite = CallSite.empty, val scope: Option[RDDOperationScope] = None) extends Ordered[RDDInfo] { @@ -56,6 +57,7 @@ private[spark] object RDDInfo { def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) - new RDDInfo(rdd.id, rddName, rdd.partitions.length, rdd.getStorageLevel, parentIds, rdd.scope) + new RDDInfo(rdd.id, rddName, rdd.partitions.length, + rdd.getStorageLevel, parentIds, rdd.creationSite, rdd.scope) } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 99085ada9f0af..4608bce202ec8 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -102,6 +102,10 @@ private[spark] class SparkUI private ( Iterator(new ApplicationInfo( id = appId, name = appName, + coresGranted = None, + maxCores = None, + coresPerExecutor = None, + memoryPerExecutorMB = None, attempts = Seq(new ApplicationAttemptInfo( attemptId = None, startTime = new Date(startTime), diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 77d034fa5ba2c..ca37829216f22 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -21,8 +21,6 @@ import java.util.concurrent.TimeoutException import scala.collection.mutable.{HashMap, HashSet, ListBuffer} -import com.google.common.annotations.VisibleForTesting - import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 51425e599e748..1b34ba9f03c44 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -28,7 +28,7 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.{InternalAccumulator, SparkConf} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} +import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo, TaskLocality} import org.apache.spark.ui._ import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Utils, Distribution} @@ -70,6 +70,21 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { private val displayPeakExecutionMemory = parent.conf.getBoolean("spark.sql.unsafe.enabled", true) + private def getLocalitySummaryString(stageData: StageUIData): String = { + val localities = stageData.taskData.values.map(_.taskInfo.taskLocality) + val localityCounts = localities.groupBy(identity).mapValues(_.size) + val localityNamesAndCounts = localityCounts.toSeq.map { case (locality, count) => + val localityName = locality match { + case TaskLocality.PROCESS_LOCAL => "Process local" + case TaskLocality.NODE_LOCAL => "Node local" + case TaskLocality.RACK_LOCAL => "Rack local" + case TaskLocality.ANY => "Any" + } + s"$localityName: $count" + } + localityNamesAndCounts.sorted.mkString("; ") + } + def render(request: HttpServletRequest): Seq[Node] = { progressListener.synchronized { val parameterId = request.getParameter("id") @@ -129,6 +144,10 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Total Time Across All Tasks: {UIUtils.formatDuration(stageData.executorRunTime)} +
  • + Locality Level Summary: + {getLocalitySummaryString(stageData)} +
  • {if (stageData.hasInput) {
  • Input Size / Records: diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 81f168a447ead..24274562657b3 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.{StringBuilder, ListBuffer} import org.apache.spark.Logging import org.apache.spark.scheduler.StageInfo import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.CallSite /** * A representation of a generic cluster graph used for storing information on RDD operations. @@ -38,7 +39,7 @@ private[ui] case class RDDOperationGraph( rootCluster: RDDOperationCluster) /** A node in an RDDOperationGraph. This represents an RDD. */ -private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean) +private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean, callsite: CallSite) /** * A directed edge connecting two nodes in an RDDOperationGraph. @@ -104,8 +105,8 @@ private[ui] object RDDOperationGraph extends Logging { edges ++= rdd.parentIds.map { parentId => RDDOperationEdge(parentId, rdd.id) } // TODO: differentiate between the intention to cache an RDD and whether it's actually cached - val node = nodes.getOrElseUpdate( - rdd.id, RDDOperationNode(rdd.id, rdd.name, rdd.storageLevel != StorageLevel.NONE)) + val node = nodes.getOrElseUpdate(rdd.id, RDDOperationNode( + rdd.id, rdd.name, rdd.storageLevel != StorageLevel.NONE, rdd.callSite)) if (rdd.scope.isEmpty) { // This RDD has no encompassing scope, so we put it directly in the root cluster @@ -177,7 +178,8 @@ private[ui] object RDDOperationGraph extends Logging { /** Return the dot representation of a node in an RDDOperationGraph. */ private def makeDotNode(node: RDDOperationNode): String = { - s"""${node.id} [label="${node.name} [${node.id}]"]""" + val label = s"${node.name} [${node.id}]\n${node.callsite.shortForm}" + s"""${node.id} [label="$label"]""" } /** Update the dot representation of the RDDOperationGraph in cluster to subgraph. */ diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala index 61b5a4cecddce..c20627b056bef 100644 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -20,7 +20,6 @@ package org.apache.spark.util import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean -import com.google.common.annotations.VisibleForTesting import org.apache.spark.SparkContext /** @@ -122,8 +121,8 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri * For testing only. Wait until there are no more events in the queue, or until the specified * time has elapsed. Throw `TimeoutException` if the specified time elapsed before the queue * emptied. + * Exposed for testing. */ - @VisibleForTesting @throws(classOf[TimeoutException]) def waitUntilEmpty(timeoutMillis: Long): Unit = { val finishTime = System.currentTimeMillis + timeoutMillis @@ -140,8 +139,8 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri /** * For testing only. Return whether the listener daemon thread is still alive. + * Exposed for testing. */ - @VisibleForTesting def listenerThreadIsAlive: Boolean = listenerThread.isAlive /** diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 1b49dca9dc78b..e27d2e6c94f7b 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -21,8 +21,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.{Map, Set} -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ +import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor, Type} +import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.{Logging, SparkEnv, SparkException} @@ -325,11 +325,11 @@ private[spark] object ClosureCleaner extends Logging { private[spark] class ReturnStatementInClosureException extends SparkException("Return statements aren't allowed in Spark closures") -private class ReturnStatementFinder extends ClassVisitor(ASM4) { +private class ReturnStatementFinder extends ClassVisitor(ASM5) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { if (name.contains("apply")) { - new MethodVisitor(ASM4) { + new MethodVisitor(ASM5) { override def visitTypeInsn(op: Int, tp: String) { if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) { throw new ReturnStatementInClosureException @@ -337,7 +337,7 @@ private class ReturnStatementFinder extends ClassVisitor(ASM4) { } } } else { - new MethodVisitor(ASM4) {} + new MethodVisitor(ASM5) {} } } } @@ -361,7 +361,7 @@ private[util] class FieldAccessFinder( findTransitively: Boolean, specificMethod: Option[MethodIdentifier[_]] = None, visitedMethods: Set[MethodIdentifier[_]] = Set.empty) - extends ClassVisitor(ASM4) { + extends ClassVisitor(ASM5) { override def visitMethod( access: Int, @@ -376,7 +376,7 @@ private[util] class FieldAccessFinder( return null } - new MethodVisitor(ASM4) { + new MethodVisitor(ASM5) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { @@ -385,7 +385,8 @@ private[util] class FieldAccessFinder( } } - override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) { // Check for calls a getter method for a variable in an interpreter wrapper object. // This means that the corresponding field will be accessed, so we should save it. @@ -408,7 +409,7 @@ private[util] class FieldAccessFinder( } } -private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { +private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM5) { var myName: String = null // TODO: Recursively find inner closures that we indirectly reference, e.g. @@ -423,9 +424,9 @@ private class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - new MethodVisitor(ASM4) { - override def visitMethodInsn(op: Int, owner: String, name: String, - desc: String) { + new MethodVisitor(ASM5) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { val argTypes = Type.getArgumentTypes(desc) if (op == INVOKESPECIAL && name == "" && argTypes.length > 0 && argTypes(0).toString.startsWith("L") // is it an object? diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index ee2eb58cf5e2a..c9beeb25e05af 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -398,6 +398,7 @@ private[spark] object JsonProtocol { ("RDD ID" -> rddInfo.id) ~ ("Name" -> rddInfo.name) ~ ("Scope" -> rddInfo.scope.map(_.toJson)) ~ + ("Callsite" -> callsiteToJson(rddInfo.callSite)) ~ ("Parent IDs" -> parentIds) ~ ("Storage Level" -> storageLevel) ~ ("Number of Partitions" -> rddInfo.numPartitions) ~ @@ -407,6 +408,11 @@ private[spark] object JsonProtocol { ("Disk Size" -> rddInfo.diskSize) } + def callsiteToJson(callsite: CallSite): JValue = { + ("Short Form" -> callsite.shortForm) ~ + ("Long Form" -> callsite.longForm) + } + def storageLevelToJson(storageLevel: StorageLevel): JValue = { ("Use Disk" -> storageLevel.useDisk) ~ ("Use Memory" -> storageLevel.useMemory) ~ @@ -851,6 +857,9 @@ private[spark] object JsonProtocol { val scope = Utils.jsonOption(json \ "Scope") .map(_.extract[String]) .map(RDDOperationScope.fromJson) + val callsite = Utils.jsonOption(json \ "Callsite") + .map(callsiteFromJson) + .getOrElse(CallSite.empty) val parentIds = Utils.jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) @@ -863,7 +872,7 @@ private[spark] object JsonProtocol { .getOrElse(json \ "Tachyon Size").extract[Long] val diskSize = (json \ "Disk Size").extract[Long] - val rddInfo = new RDDInfo(rddId, name, numPartitions, storageLevel, parentIds, scope) + val rddInfo = new RDDInfo(rddId, name, numPartitions, storageLevel, parentIds, callsite, scope) rddInfo.numCachedPartitions = numCachedPartitions rddInfo.memSize = memSize rddInfo.externalBlockStoreSize = externalBlockStoreSize @@ -871,6 +880,12 @@ private[spark] object JsonProtocol { rddInfo } + def callsiteFromJson(json: JValue): CallSite = { + val shortForm = (json \ "Short Form").extract[String] + val longForm = (json \ "Long Form").extract[String] + CallSite(shortForm, longForm) + } + def storageLevelFromJson(json: JValue): StorageLevel = { val useDisk = (json \ "Use Disk").extract[Boolean] val useMemory = (json \ "Use Memory").extract[Boolean] diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5a976ee839b1e..1b3acb8ef7f51 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,8 +21,8 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer -import java.util.{Properties, Locale, Random, UUID} import java.util.concurrent._ +import java.util.{Locale, Properties, Random, UUID} import javax.net.ssl.HttpsURLConnection import scala.collection.JavaConverters._ @@ -30,7 +30,7 @@ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source import scala.reflect.ClassTag -import scala.util.{Failure, Success, Try} +import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} import com.google.common.io.{ByteStreams, Files} @@ -42,7 +42,6 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ - import tachyon.TachyonURI import tachyon.client.{TachyonFS, TachyonFile} @@ -57,6 +56,7 @@ private[spark] case class CallSite(shortForm: String, longForm: String) private[spark] object CallSite { val SHORT_FORM = "callSite.short" val LONG_FORM = "callSite.long" + val empty = CallSite("", "") } /** @@ -2168,6 +2168,13 @@ private[spark] object Utils extends Logging { val resource = createResource try f.apply(resource) finally resource.close() } + + /** + * Returns a path of temporary file which is in the same directory with `path`. + */ + def tempFileWith(path: File): File = { + new File(path.getAbsolutePath + "." + UUID.randomUUID()) + } } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index 85c5bdbfcebc0..7ab67fc3a2de9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -17,21 +17,14 @@ package org.apache.spark.util.collection -import java.io.{Externalizable, ObjectInput, ObjectOutput} - -import org.apache.spark.util.{Utils => UUtils} - - /** * A simple, fixed-size bit set implementation. This implementation is fast because it avoids * safety/bound checking. */ -class BitSet(private[this] var numBits: Int) extends Externalizable { +class BitSet(numBits: Int) extends Serializable { - private var words = new Array[Long](bit2words(numBits)) - private def numWords = words.length - - def this() = this(0) + private val words = new Array[Long](bit2words(numBits)) + private val numWords = words.length /** * Compute the capacity (number of bits) that can be represented @@ -237,19 +230,4 @@ class BitSet(private[this] var numBits: Int) extends Externalizable { /** Return the number of longs it would take to hold numBits. */ private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1 - - override def writeExternal(out: ObjectOutput): Unit = UUtils.tryOrIOException { - out.writeInt(numBits) - words.foreach(out.writeLong(_)) - } - - override def readExternal(in: ObjectInput): Unit = UUtils.tryOrIOException { - numBits = in.readInt() - words = new Array[Long](bit2words(numBits)) - var index = 0 - while (index < words.length) { - words(index) = in.readLong() - index += 1 - } - } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index a44e72b7c16d3..2440139ac95e9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -23,7 +23,6 @@ import java.util.Comparator import scala.collection.mutable.ArrayBuffer import scala.collection.mutable -import com.google.common.annotations.VisibleForTesting import com.google.common.io.ByteStreams import org.apache.spark._ @@ -608,8 +607,8 @@ private[spark] class ExternalSorter[K, V, C]( * * For now, we just merge all the spilled files in once pass, but this can be modified to * support hierarchical merging. + * Exposed for testing. */ - @VisibleForTesting def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer @@ -639,7 +638,6 @@ private[spark] class ExternalSorter[K, V, C]( * called by the SortShuffleWriter. * * @param blockId block ID to write to. The index file will be blockId.name + ".index". - * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) */ def writePartitionedFile( diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala new file mode 100644 index 0000000000000..0b19861fc41ee --- /dev/null +++ b/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort + +import java.io.{File, FileInputStream, FileOutputStream} + +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.mockito.{Mock, MockitoAnnotations} +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.storage._ +import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkFunSuite} + + +class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEach { + + @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ + @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _ + + private var tempDir: File = _ + private val conf: SparkConf = new SparkConf(loadDefaults = false) + + override def beforeEach(): Unit = { + tempDir = Utils.createTempDir() + MockitoAnnotations.initMocks(this) + + when(blockManager.diskBlockManager).thenReturn(diskBlockManager) + when(diskBlockManager.getFile(any[BlockId])).thenAnswer( + new Answer[File] { + override def answer(invocation: InvocationOnMock): File = { + new File(tempDir, invocation.getArguments.head.toString) + } + }) + } + + override def afterEach(): Unit = { + Utils.deleteRecursively(tempDir) + } + + test("commit shuffle files multiple times") { + val lengths = Array[Long](10, 0, 20) + val resolver = new IndexShuffleBlockResolver(conf, blockManager) + val dataTmp = File.createTempFile("shuffle", null, tempDir) + val out = new FileOutputStream(dataTmp) + out.write(new Array[Byte](30)) + out.close() + resolver.writeIndexFileAndCommit(1, 2, lengths, dataTmp) + + val dataFile = resolver.getDataFile(1, 2) + assert(dataFile.exists()) + assert(dataFile.length() === 30) + assert(!dataTmp.exists()) + + val dataTmp2 = File.createTempFile("shuffle", null, tempDir) + val out2 = new FileOutputStream(dataTmp2) + val lengths2 = new Array[Long](3) + out2.write(Array[Byte](1)) + out2.write(new Array[Byte](29)) + out2.close() + resolver.writeIndexFileAndCommit(1, 2, lengths2, dataTmp2) + assert(lengths2.toSeq === lengths.toSeq) + assert(dataFile.exists()) + assert(dataFile.length() === 30) + assert(!dataTmp2.exists()) + + // The dataFile should be the previous one + val in = new FileInputStream(dataFile) + val firstByte = new Array[Byte](1) + in.read(firstByte) + assert(firstByte(0) === 0) + + // remove data file + dataFile.delete() + + val dataTmp3 = File.createTempFile("shuffle", null, tempDir) + val out3 = new FileOutputStream(dataTmp3) + val lengths3 = Array[Long](10, 10, 15) + out3.write(Array[Byte](2)) + out3.write(new Array[Byte](34)) + out3.close() + resolver.writeIndexFileAndCommit(1, 2, lengths3, dataTmp3) + assert(lengths3.toSeq != lengths.toSeq) + assert(dataFile.exists()) + assert(dataFile.length() === 35) + assert(!dataTmp2.exists()) + + // The dataFile should be the previous one + val in2 = new FileInputStream(dataFile) + val firstByte2 = new Array[Byte](1) + in2.read(firstByte2) + assert(firstByte2(0) === 2) + } +} diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 0e0eca515afc1..bc85918c59aab 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -130,7 +130,8 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th (Integer) args[3], new CompressStream(), false, - (ShuffleWriteMetrics) args[4] + (ShuffleWriteMetrics) args[4], + (BlockId) args[0] ); } }); @@ -169,9 +170,13 @@ public OutputStream answer(InvocationOnMock invocation) throws Throwable { @Override public Void answer(InvocationOnMock invocationOnMock) throws Throwable { partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; + File tmp = (File) invocationOnMock.getArguments()[3]; + mergedOutputFile.delete(); + tmp.renameTo(mergedOutputFile); return null; } - }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class)); + }).when(shuffleBlockResolver) + .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class)); when(diskBlockManager.createTempShuffleBlock()).thenAnswer( new Answer>() { diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 3bca790f30870..d87a1d2a56d99 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -117,7 +117,8 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th (Integer) args[3], new CompressStream(), false, - (ShuffleWriteMetrics) args[4] + (ShuffleWriteMetrics) args[4], + (BlockId) args[0] ); } }); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 11c3a7be38875..a1c9f6fab8e65 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -130,7 +130,8 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th (Integer) args[3], new CompressStream(), false, - (ShuffleWriteMetrics) args[4] + (ShuffleWriteMetrics) args[4], + (BlockId) args[0] ); } }); diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 4a0877d86f2c6..0de10ae485378 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -17,12 +17,16 @@ package org.apache.spark +import java.util.concurrent.{Callable, Executors, ExecutorService, CyclicBarrier} + import org.scalatest.Matchers import org.apache.spark.ShuffleSuite.NonJavaSerializableClass +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD} -import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} +import org.apache.spark.scheduler.{MyRDD, MapStatus, SparkListener, SparkListenerTaskEnd} import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.shuffle.ShuffleWriter import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId} import org.apache.spark.util.MutablePair @@ -317,6 +321,107 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(metrics.bytesWritten === metrics.byresRead) assert(metrics.bytesWritten > 0) } + + test("multiple simultaneous attempts for one task (SPARK-8029)") { + sc = new SparkContext("local", "test", conf) + val mapTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + val manager = sc.env.shuffleManager + + val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0L) + val metricsSystem = sc.env.metricsSystem + val shuffleMapRdd = new MyRDD(sc, 1, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) + val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep) + + // first attempt -- its successful + val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, + new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, metricsSystem, + InternalAccumulator.create(sc))) + val data1 = (1 to 10).map { x => x -> x} + + // second attempt -- also successful. We'll write out different data, + // just to simulate the fact that the records may get written differently + // depending on what gets spilled, what gets combined, etc. + val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, + new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, metricsSystem, + InternalAccumulator.create(sc))) + val data2 = (11 to 20).map { x => x -> x} + + // interleave writes of both attempts -- we want to test that both attempts can occur + // simultaneously, and everything is still OK + + def writeAndClose( + writer: ShuffleWriter[Int, Int])( + iter: Iterator[(Int, Int)]): Option[MapStatus] = { + val files = writer.write(iter) + writer.stop(true) + } + val interleaver = new InterleaveIterators( + data1, writeAndClose(writer1), data2, writeAndClose(writer2)) + val (mapOutput1, mapOutput2) = interleaver.run() + + // check that we can read the map output and it has the right data + assert(mapOutput1.isDefined) + assert(mapOutput2.isDefined) + assert(mapOutput1.get.location === mapOutput2.get.location) + assert(mapOutput1.get.getSizeForBlock(0) === mapOutput1.get.getSizeForBlock(0)) + + // register one of the map outputs -- doesn't matter which one + mapOutput1.foreach { case mapStatus => + mapTrackerMaster.registerMapOutputs(0, Array(mapStatus)) + } + + val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, + new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, metricsSystem, + InternalAccumulator.create(sc))) + val readData = reader.read().toIndexedSeq + assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) + + manager.unregisterShuffle(0) + } +} + +/** + * Utility to help tests make sure that we can process two different iterators simultaneously + * in different threads. This makes sure that in your test, you don't completely process data1 with + * f1 before processing data2 with f2 (or vice versa). It adds a barrier so that the functions only + * process one element, before pausing to wait for the other function to "catch up". + */ +class InterleaveIterators[T, R]( + data1: Seq[T], + f1: Iterator[T] => R, + data2: Seq[T], + f2: Iterator[T] => R) { + + require(data1.size == data2.size) + + val barrier = new CyclicBarrier(2) + class BarrierIterator[E](id: Int, sub: Iterator[E]) extends Iterator[E] { + def hasNext: Boolean = sub.hasNext + + def next: E = { + barrier.await() + sub.next() + } + } + + val c1 = new Callable[R] { + override def call(): R = f1(new BarrierIterator(1, data1.iterator)) + } + val c2 = new Callable[R] { + override def call(): R = f2(new BarrierIterator(2, data2.iterator)) + } + + val e: ExecutorService = Executors.newFixedThreadPool(2) + + def run(): (R, R) = { + val future1 = e.submit(c1) + val future2 = e.submit(c2) + val r1 = future1.get() + val r2 = future2.get() + e.shutdown() + (r1, r2) + } } object ShuffleSuite { diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index 1ed4bae3ca21e..cc30ba223e1c3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -33,8 +33,12 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate +import org.apache.spark.util.ResetSystemProperties -class RPackageUtilsSuite extends SparkFunSuite with BeforeAndAfterEach { +class RPackageUtilsSuite + extends SparkFunSuite + with BeforeAndAfterEach + with ResetSystemProperties { private val main = MavenCoordinate("a", "b", "c") private val dep1 = MavenCoordinate("a", "dep1", "c") @@ -60,11 +64,9 @@ class RPackageUtilsSuite extends SparkFunSuite with BeforeAndAfterEach { } } - def beforeAll() { - System.setProperty("spark.testing", "true") - } - override def beforeEach(): Unit = { + super.beforeEach() + System.setProperty("spark.testing", "true") lineBuffer.clear() } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 1fd470cd3b01d..42e748ec6d528 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -23,11 +23,12 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Charsets.UTF_8 import com.google.common.io.ByteStreams -import org.scalatest.Matchers +import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.util.{ResetSystemProperties, Utils} @@ -37,10 +38,12 @@ import org.apache.spark.util.{ResetSystemProperties, Utils} class SparkSubmitSuite extends SparkFunSuite with Matchers + with BeforeAndAfterEach with ResetSystemProperties with Timeouts { - def beforeAll() { + override def beforeEach() { + super.beforeEach() System.setProperty("spark.testing", "true") } @@ -367,9 +370,6 @@ class SparkSubmitSuite } test("correctly builds R packages included in a jar with --packages") { - // TODO(SPARK-9603): Building a package to $SPARK_HOME/R/lib is unavailable on Jenkins. - // It's hard to write the test in SparkR (because we can't create the repository dynamically) - /* assume(RUtils.isRInstalled, "R isn't installed on this machine.") val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -387,7 +387,6 @@ class SparkSubmitSuite rScriptDir) runSparkSubmit(args) } - */ } test("resolves command line argument paths correctly") { diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala new file mode 100644 index 0000000000000..1e5c05a73f8aa --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.client + +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.concurrent.duration._ + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark._ +import org.apache.spark.deploy.{ApplicationDescription, Command} +import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.master.{ApplicationInfo, Master} +import org.apache.spark.deploy.worker.Worker +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.util.Utils + +/** + * End-to-end tests for application client in standalone mode. + */ +class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfterAll { + private val numWorkers = 2 + private val conf = new SparkConf() + private val securityManager = new SecurityManager(conf) + + private var masterRpcEnv: RpcEnv = null + private var workerRpcEnvs: Seq[RpcEnv] = null + private var master: Master = null + private var workers: Seq[Worker] = null + + /** + * Start the local cluster. + * Note: local-cluster mode is insufficient because we want a reference to the Master. + */ + override def beforeAll(): Unit = { + super.beforeAll() + masterRpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityManager) + workerRpcEnvs = (0 until numWorkers).map { i => + RpcEnv.create(Worker.SYSTEM_NAME + i, "localhost", 0, conf, securityManager) + } + master = makeMaster() + workers = makeWorkers(10, 2048) + // Wait until all workers register with master successfully + eventually(timeout(60.seconds), interval(10.millis)) { + assert(getMasterState.workers.size === numWorkers) + } + } + + override def afterAll(): Unit = { + workerRpcEnvs.foreach(_.shutdown()) + masterRpcEnv.shutdown() + workers.foreach(_.stop()) + master.stop() + workerRpcEnvs = null + masterRpcEnv = null + workers = null + master = null + super.afterAll() + } + + test("interface methods of AppClient using local Master") { + val ci = new AppClientInst(masterRpcEnv.address.toSparkURL) + + ci.client.start() + + // Client should connect with one Master which registers the application + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(ci.listener.connectedIdList.size === 1, "client listener should have one connection") + assert(apps.size === 1, "master should have 1 registered app") + } + + // Send message to Master to request Executors, verify request by change in executor limit + val numExecutorsRequested = 1 + assert(ci.client.requestTotalExecutors(numExecutorsRequested)) + + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.head.getExecutorLimit === numExecutorsRequested, s"executor request failed") + } + + // Send request to kill executor, verify request was made + assert { + val apps = getApplications() + val executorId: String = apps.head.executors.head._2.fullId + ci.client.killExecutors(Seq(executorId)) + } + + // Issue stop command for Client to disconnect from Master + ci.client.stop() + + // Verify Client is marked dead and unregistered from Master + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(ci.listener.deadReasonList.size === 1, "client should have been marked dead") + assert(apps.isEmpty, "master should have 0 registered apps") + } + } + + test("request from AppClient before initialized with master") { + val ci = new AppClientInst(masterRpcEnv.address.toSparkURL) + + // requests to master should fail immediately + assert(ci.client.requestTotalExecutors(3) === false) + } + + // =============================== + // | Utility methods for testing | + // =============================== + + /** Return a SparkConf for applications that want to talk to our Master. */ + private def appConf: SparkConf = { + new SparkConf() + .setMaster(masterRpcEnv.address.toSparkURL) + .setAppName("test") + .set("spark.executor.memory", "256m") + } + + /** Make a master to which our application will send executor requests. */ + private def makeMaster(): Master = { + val master = new Master(masterRpcEnv, masterRpcEnv.address, 0, securityManager, conf) + masterRpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + master + } + + /** Make a few workers that talk to our master. */ + private def makeWorkers(cores: Int, memory: Int): Seq[Worker] = { + (0 until numWorkers).map { i => + val rpcEnv = workerRpcEnvs(i) + val worker = new Worker(rpcEnv, 0, cores, memory, Array(masterRpcEnv.address), + Worker.SYSTEM_NAME + i, Worker.ENDPOINT_NAME, null, conf, securityManager) + rpcEnv.setupEndpoint(Worker.ENDPOINT_NAME, worker) + worker + } + } + + /** Get the Master state */ + private def getMasterState: MasterStateResponse = { + master.self.askWithRetry[MasterStateResponse](RequestMasterState) + } + + /** Get the applictions that are active from Master */ + private def getApplications(): Seq[ApplicationInfo] = { + getMasterState.activeApps + } + + /** Application Listener to collect events */ + private class AppClientCollector extends AppClientListener with Logging { + val connectedIdList = new ArrayBuffer[String] with SynchronizedBuffer[String] + @volatile var disconnectedCount: Int = 0 + val deadReasonList = new ArrayBuffer[String] with SynchronizedBuffer[String] + val execAddedList = new ArrayBuffer[String] with SynchronizedBuffer[String] + val execRemovedList = new ArrayBuffer[String] with SynchronizedBuffer[String] + + def connected(id: String): Unit = { + connectedIdList += id + } + + def disconnected(): Unit = { + synchronized { + disconnectedCount += 1 + } + } + + def dead(reason: String): Unit = { + deadReasonList += reason + } + + def executorAdded( + id: String, + workerId: String, + hostPort: String, + cores: Int, + memory: Int): Unit = { + execAddedList += id + } + + def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit = { + execRemovedList += id + } + } + + /** Create AppClient and supporting objects */ + private class AppClientInst(masterUrl: String) { + val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, securityManager) + private val cmd = new Command(TestExecutor.getClass.getCanonicalName.stripSuffix("$"), + List(), Map(), Seq(), Seq(), Seq()) + private val desc = new ApplicationDescription("AppClientSuite", Some(1), 512, cmd, "ignored") + val listener = new AppClientCollector + val client = new AppClient(rpcEnv, Array(masterUrl), desc, listener, new SparkConf) + } + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala new file mode 100644 index 0000000000000..fba835f054f8a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master.ui + +import java.util.Date + +import scala.io.Source +import scala.language.postfixOps + +import org.json4s.jackson.JsonMethods._ +import org.json4s.JsonAST.{JNothing, JString, JInt} +import org.mockito.Mockito.{mock, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SecurityManager, SparkFunSuite} +import org.apache.spark.deploy.DeployMessages.MasterStateResponse +import org.apache.spark.deploy.DeployTestUtils._ +import org.apache.spark.deploy.master._ +import org.apache.spark.rpc.RpcEnv + + +class MasterWebUISuite extends SparkFunSuite with BeforeAndAfter { + + val masterPage = mock(classOf[MasterPage]) + val master = { + val conf = new SparkConf + val securityMgr = new SecurityManager(conf) + val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityMgr) + val master = new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf) + master + } + val masterWebUI = new MasterWebUI(master, 0, customMasterPage = Some(masterPage)) + + before { + masterWebUI.bind() + } + + after { + masterWebUI.stop() + } + + test("list applications") { + val worker = createWorkerInfo() + val appDesc = createAppDesc() + // use new start date so it isn't filtered by UI + val activeApp = new ApplicationInfo( + new Date().getTime, "id", appDesc, new Date(), null, Int.MaxValue) + activeApp.addExecutor(worker, 2) + + val workers = Array[WorkerInfo](worker) + val activeApps = Array(activeApp) + val completedApps = Array[ApplicationInfo]() + val activeDrivers = Array[DriverInfo]() + val completedDrivers = Array[DriverInfo]() + val stateResponse = new MasterStateResponse( + "host", 8080, None, workers, activeApps, completedApps, + activeDrivers, completedDrivers, RecoveryState.ALIVE) + + when(masterPage.getMasterState).thenReturn(stateResponse) + + val resultJson = Source.fromURL( + s"http://localhost:${masterWebUI.boundPort}/api/v1/applications") + .mkString + val parsedJson = parse(resultJson) + val firstApp = parsedJson(0) + + assert(firstApp \ "id" === JString(activeApp.id)) + assert(firstApp \ "name" === JString(activeApp.desc.name)) + assert(firstApp \ "coresGranted" === JInt(2)) + assert(firstApp \ "maxCores" === JInt(4)) + assert(firstApp \ "memoryPerExecutorMB" === JInt(1234)) + assert(firstApp \ "coresPerExecutor" === JNothing) + } + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 3816b8c4a09aa..4d6b25455226f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -594,11 +594,17 @@ class DAGSchedulerSuite * @param stageId - The current stageId * @param attemptIdx - The current attempt count */ - private def completeNextResultStageWithSuccess(stageId: Int, attemptIdx: Int): Unit = { + private def completeNextResultStageWithSuccess( + stageId: Int, + attemptIdx: Int, + partitionToResult: Int => Int = _ => 42): Unit = { val stageAttempt = taskSets.last checkStageId(stageId, attemptIdx, stageAttempt) assert(scheduler.stageIdToStage(stageId).isInstanceOf[ResultStage]) - complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map(_ => (Success, 42)).toSeq) + val taskResults = stageAttempt.tasks.zipWithIndex.map { case (task, idx) => + (Success, partitionToResult(idx)) + } + complete(stageAttempt, taskResults.toSeq) } /** @@ -1054,6 +1060,47 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + /** + * Run two jobs, with a shared dependency. We simulate a fetch failure in the second job, which + * requires regenerating some outputs of the shared dependency. One key aspect of this test is + * that the second job actually uses a different stage for the shared dependency (a "skipped" + * stage). + */ + test("shuffle fetch failure in a reused shuffle dependency") { + // Run the first job successfully, which creates one shuffle dependency + + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + + completeShuffleMapStageSuccessfully(0, 0, 2) + completeNextResultStageWithSuccess(1, 0) + assert(results === Map(0 -> 42, 1 -> 42)) + assertDataStructuresEmpty() + + // submit another job w/ the shared dependency, and have a fetch failure + val reduce2 = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduce2, Array(0, 1)) + // Note that the stage numbering here is only b/c the shared dependency produces a new, skipped + // stage. If instead it reused the existing stage, then this would be stage 2 + completeNextStageWithFetchFailure(3, 0, shuffleDep) + scheduler.resubmitFailedStages() + + // the scheduler now creates a new task set to regenerate the missing map output, but this time + // using a different stage, the "skipped" one + + // SPARK-9809 -- this stage is submitted without a task for each partition (because some of + // the shuffle map output is still available from stage 0); make sure we've still got internal + // accumulators setup + assert(scheduler.stageIdToStage(2).internalAccumulators.nonEmpty) + completeShuffleMapStageSuccessfully(2, 0, 2) + completeNextResultStageWithSuccess(3, 1, idx => idx + 1234) + assert(results === Map(0 -> 1234, 1 -> 1235)) + + assertDataStructuresEmpty() + } + /** * This test runs a three stage job, with a fetch failure in stage 1. but during the retry, we * have completions from both the first & second attempt of stage 1. So all the map output is diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index afe2e80358ca0..e428414cf6e85 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -322,6 +322,12 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val conf = new SparkConf(false) conf.set("spark.kryo.registrationRequired", "true") + // these cases require knowing the internals of RoaringBitmap a little. Blocks span 2^16 + // values, and they use a bitmap (dense) if they have more than 4096 values, and an + // array (sparse) if they use less. So we just create two cases, one sparse and one dense. + // and we use a roaring bitmap for the empty blocks, so we trigger the dense case w/ mostly + // empty blocks + val ser = new KryoSerializer(conf).newInstance() val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index b92a302806f76..d3b1b2b620b4d 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -68,6 +68,17 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf))) when(taskContext.taskMetrics()).thenReturn(taskMetrics) when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) + doAnswer(new Answer[Void] { + def answer(invocationOnMock: InvocationOnMock): Void = { + val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] + if (tmp != null) { + outputFile.delete + tmp.renameTo(outputFile) + } + null + } + }).when(blockResolver) + .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])) when(blockManager.diskBlockManager).thenReturn(diskBlockManager) when(blockManager.getDiskWriter( any[BlockId], @@ -84,7 +95,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte args(3).asInstanceOf[Int], compressStream = identity, syncWrites = false, - args(4).asInstanceOf[ShuffleWriteMetrics] + args(4).asInstanceOf[ShuffleWriteMetrics], + blockId = args(0).asInstanceOf[BlockId] ) } }) diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 18eec7da9763e..ceecfd665bf87 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -615,29 +615,29 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + "label="Stage 0";\n subgraph ")) assert(stage0.contains("{\n label="parallelize";\n " + - "0 [label="ParallelCollectionRDD [0]"];\n }")) + "0 [label="ParallelCollectionRDD [0]")) assert(stage0.contains("{\n label="map";\n " + - "1 [label="MapPartitionsRDD [1]"];\n }")) + "1 [label="MapPartitionsRDD [1]")) assert(stage0.contains("{\n label="groupBy";\n " + - "2 [label="MapPartitionsRDD [2]"];\n }")) + "2 [label="MapPartitionsRDD [2]")) val stage1 = Source.fromURL(sc.ui.get.appUIAddress + "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + "label="Stage 1";\n subgraph ")) assert(stage1.contains("{\n label="groupBy";\n " + - "3 [label="ShuffledRDD [3]"];\n }")) + "3 [label="ShuffledRDD [3]")) assert(stage1.contains("{\n label="map";\n " + - "4 [label="MapPartitionsRDD [4]"];\n }")) + "4 [label="MapPartitionsRDD [4]")) assert(stage1.contains("{\n label="groupBy";\n " + - "5 [label="MapPartitionsRDD [5]"];\n }")) + "5 [label="MapPartitionsRDD [5]")) val stage2 = Source.fromURL(sc.ui.get.appUIAddress + "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + "label="Stage 2";\n subgraph ")) assert(stage2.contains("{\n label="groupBy";\n " + - "6 [label="ShuffledRDD [6]"];\n }")) + "6 [label="ShuffledRDD [6]")) } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 953456c2caa89..3f94ef7041914 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -111,6 +111,7 @@ class JsonProtocolSuite extends SparkFunSuite { test("Dependent Classes") { val logUrlMap = Map("stderr" -> "mystderr", "stdout" -> "mystdout").toMap testRDDInfo(makeRddInfo(2, 3, 4, 5L, 6L)) + testCallsite(CallSite("happy", "birthday")) testStageInfo(makeStageInfo(10, 20, 30, 40L, 50L)) testTaskInfo(makeTaskInfo(999L, 888, 55, 777L, false)) testTaskMetrics(makeTaskMetrics( @@ -163,6 +164,10 @@ class JsonProtocolSuite extends SparkFunSuite { testBlockId(StreamBlockId(1, 2L)) } + /* ============================== * + | Backward compatibility tests | + * ============================== */ + test("ExceptionFailure backward compatibility") { val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None, None) @@ -334,14 +339,17 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(expectedJobEnd, JsonProtocol.jobEndFromJson(oldEndEvent)) } - test("RDDInfo backward compatibility (scope, parent IDs)") { - // Prior to Spark 1.4.0, RDDInfo did not have the "Scope" and "Parent IDs" properties - val rddInfo = new RDDInfo( - 1, "one", 100, StorageLevel.NONE, Seq(1, 6, 8), Some(new RDDOperationScope("fable"))) + test("RDDInfo backward compatibility (scope, parent IDs, callsite)") { + // "Scope" and "Parent IDs" were introduced in Spark 1.4.0 + // "Callsite" was introduced in Spark 1.6.0 + val rddInfo = new RDDInfo(1, "one", 100, StorageLevel.NONE, Seq(1, 6, 8), + CallSite("short", "long"), Some(new RDDOperationScope("fable"))) val oldRddInfoJson = JsonProtocol.rddInfoToJson(rddInfo) .removeField({ _._1 == "Parent IDs"}) .removeField({ _._1 == "Scope"}) - val expectedRddInfo = new RDDInfo(1, "one", 100, StorageLevel.NONE, Seq.empty, scope = None) + .removeField({ _._1 == "Callsite"}) + val expectedRddInfo = new RDDInfo( + 1, "one", 100, StorageLevel.NONE, Seq.empty, CallSite.empty, scope = None) assertEquals(expectedRddInfo, JsonProtocol.rddInfoFromJson(oldRddInfoJson)) } @@ -389,6 +397,11 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(info, newInfo) } + private def testCallsite(callsite: CallSite): Unit = { + val newCallsite = JsonProtocol.callsiteFromJson(JsonProtocol.callsiteToJson(callsite)) + assert(callsite === newCallsite) + } + private def testStageInfo(info: StageInfo) { val newInfo = JsonProtocol.stageInfoFromJson(JsonProtocol.stageInfoToJson(info)) assertEquals(info, newInfo) @@ -713,7 +726,8 @@ class JsonProtocolSuite extends SparkFunSuite { } private def makeRddInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = { - val r = new RDDInfo(a, "mayor", b, StorageLevel.MEMORY_AND_DISK, Seq(1, 4, 7)) + val r = new RDDInfo(a, "mayor", b, StorageLevel.MEMORY_AND_DISK, + Seq(1, 4, 7), CallSite(a.toString, b.toString)) r.numCachedPartitions = c r.memSize = d r.diskSize = e @@ -856,6 +870,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 101, | "Name": "mayor", + | "Callsite": {"Short Form": "101", "Long Form": "201"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1258,6 +1273,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 1, | "Name": "mayor", + | "Callsite": {"Short Form": "1", "Long Form": "200"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1301,6 +1317,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 2, | "Name": "mayor", + | "Callsite": {"Short Form": "2", "Long Form": "400"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1318,6 +1335,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 3, | "Name": "mayor", + | "Callsite": {"Short Form": "3", "Long Form": "401"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1361,6 +1379,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 3, | "Name": "mayor", + | "Callsite": {"Short Form": "3", "Long Form": "600"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1378,6 +1397,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 4, | "Name": "mayor", + | "Callsite": {"Short Form": "4", "Long Form": "601"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1395,6 +1415,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 5, | "Name": "mayor", + | "Callsite": {"Short Form": "5", "Long Form": "602"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1438,6 +1459,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 4, | "Name": "mayor", + | "Callsite": {"Short Form": "4", "Long Form": "800"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1455,6 +1477,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 5, | "Name": "mayor", + | "Callsite": {"Short Form": "5", "Long Form": "801"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1472,6 +1495,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 6, | "Name": "mayor", + | "Callsite": {"Short Form": "6", "Long Form": "802"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1489,6 +1513,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 7, | "Name": "mayor", + | "Callsite": {"Short Form": "7", "Long Form": "803"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala index b0db0988eeaab..69dbfa9cd7141 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala @@ -17,10 +17,7 @@ package org.apache.spark.util.collection -import java.io.{File, FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream} - import org.apache.spark.SparkFunSuite -import org.apache.spark.util.{Utils => UUtils} class BitSetSuite extends SparkFunSuite { @@ -155,50 +152,4 @@ class BitSetSuite extends SparkFunSuite { assert(bitsetDiff.nextSetBit(85) === 85) assert(bitsetDiff.nextSetBit(86) === -1) } - - test("read and write externally") { - val tempDir = UUtils.createTempDir() - val outputFile = File.createTempFile("bits", null, tempDir) - - val fos = new FileOutputStream(outputFile) - val oos = new ObjectOutputStream(fos) - - // Create BitSet - val setBits = Seq(0, 9, 1, 10, 90, 96) - val bitset = new BitSet(100) - - for (i <- 0 until 100) { - assert(!bitset.get(i)) - } - - setBits.foreach(i => bitset.set(i)) - - for (i <- 0 until 100) { - if (setBits.contains(i)) { - assert(bitset.get(i)) - } else { - assert(!bitset.get(i)) - } - } - assert(bitset.cardinality() === setBits.size) - - bitset.writeExternal(oos) - oos.close() - - val fis = new FileInputStream(outputFile) - val ois = new ObjectInputStream(fis) - - // Read BitSet from the file - val bitset2 = new BitSet(0) - bitset2.readExternal(ois) - - for (i <- 0 until 100) { - if (setBits.contains(i)) { - assert(bitset2.get(i)) - } else { - assert(!bitset2.get(i)) - } - } - assert(bitset2.cardinality() === setBits.size) - } } diff --git a/dev/mima b/dev/mima index 2952fa65d42ff..d5baffc6ef8a3 100755 --- a/dev/mima +++ b/dev/mima @@ -38,7 +38,7 @@ generate_mima_ignore() { # it did not process the new classes (which are in assembly jar). generate_mima_ignore -export SPARK_CLASSPATH="`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`" +export SPARK_CLASSPATH="$(build/sbt "export oldDeps/fullClasspath" | tail -n1)" echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" generate_mima_ignore diff --git a/docker-integration-tests/pom.xml b/docker-integration-tests/pom.xml new file mode 100644 index 0000000000000..dee0c4aa37ae8 --- /dev/null +++ b/docker-integration-tests/pom.xml @@ -0,0 +1,149 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.6.0-SNAPSHOT + ../pom.xml + + + spark-docker-integration-tests_2.10 + jar + Spark Project Docker Integration Tests + http://spark.apache.org/ + + docker-integration-tests + + + + + com.spotify + docker-client + shaded + test + + + + com.fasterxml.jackson.jaxrs + jackson-jaxrs-json-provider + + + com.fasterxml.jackson.datatype + jackson-datatype-guava + + + com.fasterxml.jackson.core + jackson-databind + + + org.glassfish.jersey.core + jersey-client + + + org.glassfish.jersey.connectors + jersey-apache-connector + + + org.glassfish.jersey.media + jersey-media-json-jackson + + + + + + com.google.guava + guava + 18.0 + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-test-tags_${scala.binary.version} + ${project.version} + test + + + + com.sun.jersey + jersey-server + 1.19 + test + + + com.sun.jersey + jersey-core + 1.19 + test + + + com.sun.jersey + jersey-servlet + 1.19 + test + + + com.sun.jersey + jersey-json + 1.19 + test + + + stax + stax-api + + + + + + diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala new file mode 100644 index 0000000000000..c503c4a13b482 --- /dev/null +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.net.ServerSocket +import java.sql.Connection + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import com.spotify.docker.client._ +import com.spotify.docker.client.messages.{ContainerConfig, HostConfig, PortBinding} +import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.DockerUtils +import org.apache.spark.sql.test.SharedSQLContext + +abstract class DatabaseOnDocker { + /** + * The docker image to be pulled. + */ + val imageName: String + + /** + * Environment variables to set inside of the Docker container while launching it. + */ + val env: Map[String, String] + + /** + * The container-internal JDBC port that the database listens on. + */ + val jdbcPort: Int + + /** + * Return a JDBC URL that connects to the database running at the given IP address and port. + */ + def getJdbcUrl(ip: String, port: Int): String +} + +abstract class DockerJDBCIntegrationSuite + extends SparkFunSuite + with BeforeAndAfterAll + with Eventually + with SharedSQLContext { + + val db: DatabaseOnDocker + + private var docker: DockerClient = _ + private var containerId: String = _ + protected var jdbcUrl: String = _ + + override def beforeAll() { + super.beforeAll() + try { + docker = DefaultDockerClient.fromEnv.build() + // Check that Docker is actually up + try { + docker.ping() + } catch { + case NonFatal(e) => + log.error("Exception while connecting to Docker. Check whether Docker is running.") + throw e + } + // Ensure that the Docker image is installed: + try { + docker.inspectImage(db.imageName) + } catch { + case e: ImageNotFoundException => + log.warn(s"Docker image ${db.imageName} not found; pulling image from registry") + docker.pull(db.imageName) + } + // Configure networking (necessary for boot2docker / Docker Machine) + val externalPort: Int = { + val sock = new ServerSocket(0) + val port = sock.getLocalPort + sock.close() + port + } + val dockerIp = DockerUtils.getDockerIp() + val hostConfig: HostConfig = HostConfig.builder() + .networkMode("bridge") + .portBindings( + Map(s"${db.jdbcPort}/tcp" -> List(PortBinding.of(dockerIp, externalPort)).asJava).asJava) + .build() + // Create the database container: + val config = ContainerConfig.builder() + .image(db.imageName) + .networkDisabled(false) + .env(db.env.map { case (k, v) => s"$k=$v" }.toSeq.asJava) + .hostConfig(hostConfig) + .exposedPorts(s"${db.jdbcPort}/tcp") + .build() + containerId = docker.createContainer(config).id + // Start the container and wait until the database can accept JDBC connections: + docker.startContainer(containerId) + jdbcUrl = db.getJdbcUrl(dockerIp, externalPort) + eventually(timeout(60.seconds), interval(1.seconds)) { + val conn = java.sql.DriverManager.getConnection(jdbcUrl) + conn.close() + } + // Run any setup queries: + val conn: Connection = java.sql.DriverManager.getConnection(jdbcUrl) + try { + dataPreparation(conn) + } finally { + conn.close() + } + } catch { + case NonFatal(e) => + try { + afterAll() + } finally { + throw e + } + } + } + + override def afterAll() { + try { + if (docker != null) { + try { + if (containerId != null) { + docker.killContainer(containerId) + docker.removeContainer(containerId) + } + } catch { + case NonFatal(e) => + logWarning(s"Could not stop container $containerId", e) + } finally { + docker.close() + } + } + } finally { + super.afterAll() + } + } + + /** + * Prepare databases and tables for testing. + */ + def dataPreparation(connection: Connection): Unit +} diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala new file mode 100644 index 0000000000000..c68e4dc4933b1 --- /dev/null +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.math.BigDecimal +import java.sql.{Connection, Date, Timestamp} +import java.util.Properties + +import org.apache.spark.tags.DockerTest + +@DockerTest +class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { + override val db = new DatabaseOnDocker { + override val imageName = "mysql:5.7.9" + override val env = Map( + "MYSQL_ROOT_PASSWORD" -> "rootpass" + ) + override val jdbcPort: Int = 3306 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass" + } + + override def dataPreparation(conn: Connection): Unit = { + conn.prepareStatement("CREATE DATABASE foo").executeUpdate() + conn.prepareStatement("CREATE TABLE tbl (x INTEGER, y TEXT(8))").executeUpdate() + conn.prepareStatement("INSERT INTO tbl VALUES (42,'fred')").executeUpdate() + conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate() + + conn.prepareStatement("CREATE TABLE numbers (onebit BIT(1), tenbits BIT(10), " + + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, " + + "dbl DOUBLE)").executeUpdate() + conn.prepareStatement("INSERT INTO numbers VALUES (b'0', b'1000100101', " + + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, " + + "42.75, 1.0000000000000002)").executeUpdate() + + conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, " + + "yr YEAR)").executeUpdate() + conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', " + + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate() + + // TODO: Test locale conversion for strings. + conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, " + + "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)" + ).executeUpdate() + conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', 'fox', " + + "'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate() + } + + test("Basic test") { + val df = sqlContext.read.jdbc(jdbcUrl, "tbl", new Properties) + val rows = df.collect() + assert(rows.length == 2) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 2) + assert(types(0).equals("class java.lang.Integer")) + assert(types(1).equals("class java.lang.String")) + } + + test("Numeric types") { + val df = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.Boolean")) + assert(types(1).equals("class java.lang.Long")) + assert(types(2).equals("class java.lang.Integer")) + assert(types(3).equals("class java.lang.Integer")) + assert(types(4).equals("class java.lang.Integer")) + assert(types(5).equals("class java.lang.Long")) + assert(types(6).equals("class java.math.BigDecimal")) + assert(types(7).equals("class java.lang.Double")) + assert(types(8).equals("class java.lang.Double")) + assert(rows(0).getBoolean(0) == false) + assert(rows(0).getLong(1) == 0x225) + assert(rows(0).getInt(2) == 17) + assert(rows(0).getInt(3) == 77777) + assert(rows(0).getInt(4) == 123456789) + assert(rows(0).getLong(5) == 123456789012345L) + val bd = new BigDecimal("123456789012345.12345678901234500000") + assert(rows(0).getAs[BigDecimal](6).equals(bd)) + assert(rows(0).getDouble(7) == 42.75) + assert(rows(0).getDouble(8) == 1.0000000000000002) + } + + test("Date types") { + val df = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 5) + assert(types(0).equals("class java.sql.Date")) + assert(types(1).equals("class java.sql.Timestamp")) + assert(types(2).equals("class java.sql.Timestamp")) + assert(types(3).equals("class java.sql.Timestamp")) + assert(types(4).equals("class java.sql.Date")) + assert(rows(0).getAs[Date](0).equals(Date.valueOf("1991-11-09"))) + assert(rows(0).getAs[Timestamp](1).equals(Timestamp.valueOf("1970-01-01 13:31:24"))) + assert(rows(0).getAs[Timestamp](2).equals(Timestamp.valueOf("1996-01-01 01:23:45"))) + assert(rows(0).getAs[Timestamp](3).equals(Timestamp.valueOf("2009-02-13 23:31:30"))) + assert(rows(0).getAs[Date](4).equals(Date.valueOf("2001-01-01"))) + } + + test("String types") { + val df = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.String")) + assert(types(1).equals("class java.lang.String")) + assert(types(2).equals("class java.lang.String")) + assert(types(3).equals("class java.lang.String")) + assert(types(4).equals("class java.lang.String")) + assert(types(5).equals("class java.lang.String")) + assert(types(6).equals("class [B")) + assert(types(7).equals("class [B")) + assert(types(8).equals("class [B")) + assert(rows(0).getString(0).equals("the")) + assert(rows(0).getString(1).equals("quick")) + assert(rows(0).getString(2).equals("brown")) + assert(rows(0).getString(3).equals("fox")) + assert(rows(0).getString(4).equals("jumps")) + assert(rows(0).getString(5).equals("over")) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](116, 104, 101, 0))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](7), Array[Byte](108, 97, 122, 121))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](8), Array[Byte](100, 111, 103))) + } + + test("Basic write test") { + val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val df2 = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) + val df3 = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) + df1.write.jdbc(jdbcUrl, "numberscopy", new Properties) + df2.write.jdbc(jdbcUrl, "datescopy", new Properties) + df3.write.jdbc(jdbcUrl, "stringscopy", new Properties) + } +} diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala new file mode 100644 index 0000000000000..164a7f396280c --- /dev/null +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Connection +import java.util.Properties + +import org.apache.spark.tags.DockerTest + +@DockerTest +class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { + override val db = new DatabaseOnDocker { + override val imageName = "postgres:9.4.5" + override val env = Map( + "POSTGRES_PASSWORD" -> "rootpass" + ) + override val jdbcPort = 5432 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass" + } + + override def dataPreparation(conn: Connection): Unit = { + conn.prepareStatement("CREATE DATABASE foo").executeUpdate() + conn.setCatalog("foo") + conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, " + + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate() + conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " + + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate() + } + + test("Type mapping for various types") { + val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 10) + assert(types(0).equals("class java.lang.String")) + assert(types(1).equals("class java.lang.Integer")) + assert(types(2).equals("class java.lang.Double")) + assert(types(3).equals("class java.lang.Long")) + assert(types(4).equals("class java.lang.Boolean")) + assert(types(5).equals("class [B")) + assert(types(6).equals("class [B")) + assert(types(7).equals("class java.lang.Boolean")) + assert(types(8).equals("class java.lang.String")) + assert(types(9).equals("class java.lang.String")) + assert(rows(0).getString(0).equals("hello")) + assert(rows(0).getInt(1) == 42) + assert(rows(0).getDouble(2) == 1.25) + assert(rows(0).getLong(3) == 123456789012345L) + assert(rows(0).getBoolean(4) == false) + // BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's... + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5), + Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), + Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte))) + assert(rows(0).getBoolean(7) == true) + assert(rows(0).getString(8) == "172.16.0.42") + assert(rows(0).getString(9) == "192.168.0.0/16") + } + + test("Basic write test") { + val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) + df.write.jdbc(jdbcUrl, "public.barcopy", new Properties) + // Test only that it doesn't crash. + } +} diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala b/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala new file mode 100644 index 0000000000000..87271776d8564 --- /dev/null +++ b/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.net.{Inet4Address, NetworkInterface, InetAddress} + +import scala.collection.JavaConverters._ +import scala.sys.process._ +import scala.util.Try + +private[spark] object DockerUtils { + + def getDockerIp(): String = { + /** If docker-machine is setup on this box, attempts to find the ip from it. */ + def findFromDockerMachine(): Option[String] = { + sys.env.get("DOCKER_MACHINE_NAME").flatMap { name => + Try(Seq("/bin/bash", "-c", s"docker-machine ip $name 2>/dev/null").!!.trim).toOption + } + } + sys.env.get("DOCKER_IP") + .orElse(findFromDockerMachine()) + .orElse(Try(Seq("/bin/bash", "-c", "boot2docker ip 2>/dev/null").!!.trim).toOption) + .getOrElse { + // This block of code is based on Utils.findLocalInetAddress(), but is modified to blacklist + // certain interfaces. + val address = InetAddress.getLocalHost + // Address resolves to something like 127.0.1.1, which happens on Debian; try to find + // a better address using the local network interfaces + // getNetworkInterfaces returns ifs in reverse order compared to ifconfig output order + // on unix-like system. On windows, it returns in index order. + // It's more proper to pick ip address following system output order. + val blackListedIFs = Seq( + "vboxnet0", // Mac + "docker0" // Linux + ) + val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.asScala.toSeq.filter { i => + !blackListedIFs.contains(i.getName) + } + val reOrderedNetworkIFs = activeNetworkIFs.reverse + for (ni <- reOrderedNetworkIFs) { + val addresses = ni.getInetAddresses.asScala + .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress).toSeq + if (addresses.nonEmpty) { + val addr = addresses.find(_.isInstanceOf[Inet4Address]).getOrElse(addresses.head) + // because of Inet6Address.toHostName may add interface at the end if it knows about it + val strippedAddress = InetAddress.getByAddress(addr.getAddress) + return strippedAddress.getHostAddress + } + } + address.getHostAddress + } + } +} diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb index 6ee63a5ac69df..549f81fe1b1bc 100644 --- a/docs/_plugins/include_example.rb +++ b/docs/_plugins/include_example.rb @@ -28,7 +28,7 @@ def initialize(tag_name, markup, tokens) def render(context) site = context.registers[:site] - config_dir = (site.config['code_dir'] || '../examples/src/main').sub(/^\//,'') + config_dir = '../examples/src/main' @code_dir = File.join(site.source, config_dir) clean_markup = @markup.strip @@ -38,7 +38,12 @@ def render(context) code = File.open(@file).read.encode("UTF-8") code = select_lines(code) - Pygments.highlight(code, :lexer => @lang) + rendered_code = Pygments.highlight(code, :lexer => @lang) + + hint = "
    Find full example code at " \ + "\"examples/src/main/#{clean_markup}\" in the Spark repo.
    " + + rendered_code + hint end # Trim the code block so as to have the same indention, regardless of their positions in the diff --git a/docs/building-spark.md b/docs/building-spark.md index 4f73adb85446c..3d38edbdad4bc 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -190,6 +190,10 @@ Running only Java 8 tests and nothing else. mvn install -DskipTests -Pjava8-tests +or + + sbt -Pjava8-tests java8-tests/test + Java 8 tests are run when `-Pjava8-tests` profile is enabled, they will run in spite of `-DskipTests`. For these tests to run your system must have a JDK 8 installation. If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests. diff --git a/docs/configuration.md b/docs/configuration.md index c276e8e90decf..d961f43acf4ab 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -305,7 +305,7 @@ Apart from these, the following properties are also available, and may be useful daily Set the time interval by which the executor logs will be rolled over. - Rolling is disabled by default. Valid values are daily, hourly, minutely or + Rolling is disabled by default. Valid values are daily, hourly, minutely or any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles for automatic cleaning of old logs. @@ -330,13 +330,13 @@ Apart from these, the following properties are also available, and may be useful spark.python.profile false - Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), + Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), or it will be displayed before the driver exiting. It also can be dumped into disk by - sc.dump_profiles(path). If some of the profile results had been displayed manually, + sc.dump_profiles(path). If some of the profile results had been displayed manually, they will not be displayed automatically before driver exiting. - By default the pyspark.profiler.BasicProfiler will be used, but this can be overridden by - passing a profiler class in as a parameter to the SparkContext constructor. + By default the pyspark.profiler.BasicProfiler will be used, but this can be overridden by + passing a profiler class in as a parameter to the SparkContext constructor. diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index 8d9c2ba2041b2..a3c34cb6796fa 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -56,36 +56,32 @@ provide another approach to share RDDs. ## Dynamic Resource Allocation -Spark 1.2 introduces the ability to dynamically scale the set of cluster resources allocated to -your application up and down based on the workload. This means that your application may give -resources back to the cluster if they are no longer used and request them again later when there -is demand. This feature is particularly useful if multiple applications share resources in your -Spark cluster. If a subset of the resources allocated to an application becomes idle, it can be -returned to the cluster's pool of resources and acquired by other applications. In Spark, dynamic -resource allocation is performed on the granularity of the executor and can be enabled through -`spark.dynamicAllocation.enabled`. - -This feature is currently disabled by default and available only on [YARN](running-on-yarn.html). -A future release will extend this to [standalone mode](spark-standalone.html) and -[Mesos coarse-grained mode](running-on-mesos.html#mesos-run-modes). Note that although Spark on -Mesos already has a similar notion of dynamic resource sharing in fine-grained mode, enabling -dynamic allocation allows your Mesos application to take advantage of coarse-grained low-latency -scheduling while sharing cluster resources efficiently. +Spark provides a mechanism to dynamically adjust the resources your application occupies based +on the workload. This means that your application may give resources back to the cluster if they +are no longer used and request them again later when there is demand. This feature is particularly +useful if multiple applications share resources in your Spark cluster. + +This feature is disabled by default and available on all coarse-grained cluster managers, i.e. +[standalone mode](spark-standalone.html), [YARN mode](running-on-yarn.html), and +[Mesos coarse-grained mode](running-on-mesos.html#mesos-run-modes). ### Configuration and Setup -All configurations used by this feature live under the `spark.dynamicAllocation.*` namespace. -To enable this feature, your application must set `spark.dynamicAllocation.enabled` to `true`. -Other relevant configurations are described on the -[configurations page](configuration.html#dynamic-allocation) and in the subsequent sections in -detail. +There are two requirements for using this feature. First, your application must set +`spark.dynamicAllocation.enabled` to `true`. Second, you must set up an *external shuffle service* +on each worker node in the same cluster and set `spark.shuffle.service.enabled` to true in your +application. The purpose of the external shuffle service is to allow executors to be removed +without deleting shuffle files written by them (more detail described +[below](job-scheduling.html#graceful-decommission-of-executors)). The way to set up this service +varies across cluster managers: + +In standalone mode, simply start your workers with `spark.shuffle.service.enabled` set to `true`. -Additionally, your application must use an external shuffle service. The purpose of the service is -to preserve the shuffle files written by executors so the executors can be safely removed (more -detail described [below](job-scheduling.html#graceful-decommission-of-executors)). To enable -this service, set `spark.shuffle.service.enabled` to `true`. In YARN, this external shuffle service -is implemented in `org.apache.spark.yarn.network.YarnShuffleService` that runs in each `NodeManager` -in your cluster. To start this service, follow these steps: +In Mesos coarse-grained mode, run `$SPARK_HOME/sbin/start-mesos-shuffle-service.sh` on all +slave nodes with `spark.shuffle.service.enabled` set to `true`. For instance, you may do so +through Marathon. + +In YARN mode, start the shuffle service on each `NodeManager` as follows: 1. Build Spark with the [YARN profile](building-spark.html). Skip this step if you are using a pre-packaged distribution. @@ -95,10 +91,13 @@ pre-packaged distribution. 2. Add this jar to the classpath of all `NodeManager`s in your cluster. 3. In the `yarn-site.xml` on each node, add `spark_shuffle` to `yarn.nodemanager.aux-services`, then set `yarn.nodemanager.aux-services.spark_shuffle.class` to -`org.apache.spark.network.yarn.YarnShuffleService`. Additionally, set all relevant -`spark.shuffle.service.*` [configurations](configuration.html). +`org.apache.spark.network.yarn.YarnShuffleService` and `spark.shuffle.service.enabled` to true. 4. Restart all `NodeManager`s in your cluster. +All other relevant configurations are optional and under the `spark.dynamicAllocation.*` and +`spark.shuffle.service.*` namespaces. For more detail, see the +[configurations page](configuration.html#dynamic-allocation). + ### Resource Allocation Policy At a high level, Spark should relinquish executors when they are no longer used and acquire diff --git a/docs/ml-ann.md b/docs/ml-ann.md index d5ddd92af1e96..6e763e8f41568 100644 --- a/docs/ml-ann.md +++ b/docs/ml-ann.md @@ -48,76 +48,15 @@ MLPC employes backpropagation for learning the model. We use logistic loss funct
    - -{% highlight scala %} -import org.apache.spark.ml.classification.MultilayerPerceptronClassifier -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.Row - -// Load training data -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt").toDF() -// Split the data into train and test -val splits = data.randomSplit(Array(0.6, 0.4), seed = 1234L) -val train = splits(0) -val test = splits(1) -// specify layers for the neural network: -// input layer of size 4 (features), two intermediate of size 5 and 4 and output of size 3 (classes) -val layers = Array[Int](4, 5, 4, 3) -// create the trainer and set its parameters -val trainer = new MultilayerPerceptronClassifier() - .setLayers(layers) - .setBlockSize(128) - .setSeed(1234L) - .setMaxIter(100) -// train the model -val model = trainer.fit(train) -// compute precision on the test set -val result = model.transform(test) -val predictionAndLabels = result.select("prediction", "label") -val evaluator = new MulticlassClassificationEvaluator() - .setMetricName("precision") -println("Precision:" + evaluator.evaluate(predictionAndLabels)) -{% endhighlight %} - +{% include_example scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala %}
    +{% include_example java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java %} +
    -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel; -import org.apache.spark.ml.classification.MultilayerPerceptronClassifier; -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; - -// Load training data -String path = "data/mllib/sample_multiclass_classification_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); -DataFrame dataFrame = sqlContext.createDataFrame(data, LabeledPoint.class); -// Split the data into train and test -DataFrame[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L); -DataFrame train = splits[0]; -DataFrame test = splits[1]; -// specify layers for the neural network: -// input layer of size 4 (features), two intermediate of size 5 and 4 and output of size 3 (classes) -int[] layers = new int[] {4, 5, 4, 3}; -// create the trainer and set its parameters -MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier() - .setLayers(layers) - .setBlockSize(128) - .setSeed(1234L) - .setMaxIter(100); -// train the model -MultilayerPerceptronClassificationModel model = trainer.fit(train); -// compute precision on the test set -DataFrame result = model.transform(test); -DataFrame predictionAndLabels = result.select("prediction", "label"); -MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() - .setMetricName("precision"); -System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels)); -{% endhighlight %} +
    +{% include_example python/ml/multilayer_perceptron_classification.py %}
    diff --git a/docs/ml-decision-tree.md b/docs/ml-decision-tree.md index 542819e93e6dc..2bfac6f6c8378 100644 --- a/docs/ml-decision-tree.md +++ b/docs/ml-decision-tree.md @@ -118,196 +118,24 @@ We use two feature transformers to prepare the data; these help index categories More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier). -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.DecisionTreeClassifier -import org.apache.spark.ml.classification.DecisionTreeClassificationModel -import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file, converting it to a DataFrame. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -val labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data) -// Automatically identify categorical features, and index them. -val featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) // features with > 4 distinct values are treated as continuous - .fit(data) - -// Split the data into training and test sets (30% held out for testing) -val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) - -// Train a DecisionTree model. -val dt = new DecisionTreeClassifier() - .setLabelCol("indexedLabel") - .setFeaturesCol("indexedFeatures") - -// Convert indexed labels back to original labels. -val labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels) - -// Chain indexers and tree in a Pipeline -val pipeline = new Pipeline() - .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter)) - -// Train model. This also runs the indexers. -val model = pipeline.fit(trainingData) - -// Make predictions. -val predictions = model.transform(testData) - -// Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5) - -// Select (prediction, true label) and compute test error -val evaluator = new MulticlassClassificationEvaluator() - .setLabelCol("indexedLabel") - .setPredictionCol("prediction") - .setMetricName("precision") -val accuracy = evaluator.evaluate(predictions) -println("Test Error = " + (1.0 - accuracy)) - -val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] -println("Learned classification tree model:\n" + treeModel.toDebugString) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala %} +
    More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html). -{% highlight java %} -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.DecisionTreeClassifier; -import org.apache.spark.ml.classification.DecisionTreeClassificationModel; -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; -import org.apache.spark.ml.feature.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; -import org.apache.spark.sql.DataFrame; - -// Load and parse the data file, converting it to a DataFrame. -RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); -DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); - -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -StringIndexerModel labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data); -// Automatically identify categorical features, and index them. -VectorIndexerModel featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) // features with > 4 distinct values are treated as continuous - .fit(data); - -// Split the data into training and test sets (30% held out for testing) -DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); -DataFrame trainingData = splits[0]; -DataFrame testData = splits[1]; - -// Train a DecisionTree model. -DecisionTreeClassifier dt = new DecisionTreeClassifier() - .setLabelCol("indexedLabel") - .setFeaturesCol("indexedFeatures"); - -// Convert indexed labels back to original labels. -IndexToString labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels()); - -// Chain indexers and tree in a Pipeline -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {labelIndexer, featureIndexer, dt, labelConverter}); - -// Train model. This also runs the indexers. -PipelineModel model = pipeline.fit(trainingData); - -// Make predictions. -DataFrame predictions = model.transform(testData); - -// Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5); - -// Select (prediction, true label) and compute test error -MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() - .setLabelCol("indexedLabel") - .setPredictionCol("prediction") - .setMetricName("precision"); -double accuracy = evaluator.evaluate(predictions); -System.out.println("Test Error = " + (1.0 - accuracy)); - -DecisionTreeClassificationModel treeModel = - (DecisionTreeClassificationModel)(model.stages()[2]); -System.out.println("Learned classification tree model:\n" + treeModel.toDebugString()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java %} +
    More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier). -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.classification import DecisionTreeClassifier -from pyspark.ml.feature import StringIndexer, VectorIndexer -from pyspark.ml.evaluation import MulticlassClassificationEvaluator -from pyspark.mllib.util import MLUtils - -# Load and parse the data file, converting it to a DataFrame. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -# Index labels, adding metadata to the label column. -# Fit on whole dataset to include all labels in index. -labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) -# Automatically identify categorical features, and index them. -# We specify maxCategories so features with > 4 distinct values are treated as continuous. -featureIndexer =\ - VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) - -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") - -# Chain indexers and tree in a Pipeline -pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt]) - -# Train model. This also runs the indexers. -model = pipeline.fit(trainingData) - -# Make predictions. -predictions = model.transform(testData) - -# Select example rows to display. -predictions.select("prediction", "indexedLabel", "features").show(5) - -# Select (prediction, true label) and compute test error -evaluator = MulticlassClassificationEvaluator( - labelCol="indexedLabel", predictionCol="prediction", metricName="precision") -accuracy = evaluator.evaluate(predictions) -print "Test Error = %g" % (1.0 - accuracy) +{% include_example python/ml/decision_tree_classification_example.py %} -treeModel = model.stages[2] -print treeModel # summary only -{% endhighlight %}
    @@ -323,171 +151,21 @@ We use a feature transformer to index categorical features, adding metadata to t More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.regression.DecisionTreeRegressor). -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.regression.DecisionTreeRegressor -import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.feature.VectorIndexer -import org.apache.spark.ml.evaluation.RegressionEvaluator -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file, converting it to a DataFrame. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -// Automatically identify categorical features, and index them. -// Here, we treat features with > 4 distinct values as continuous. -val featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data) - -// Split the data into training and test sets (30% held out for testing) -val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) - -// Train a DecisionTree model. -val dt = new DecisionTreeRegressor() - .setLabelCol("label") - .setFeaturesCol("indexedFeatures") - -// Chain indexer and tree in a Pipeline -val pipeline = new Pipeline() - .setStages(Array(featureIndexer, dt)) - -// Train model. This also runs the indexer. -val model = pipeline.fit(trainingData) - -// Make predictions. -val predictions = model.transform(testData) - -// Select example rows to display. -predictions.select("prediction", "label", "features").show(5) - -// Select (prediction, true label) and compute test error -val evaluator = new RegressionEvaluator() - .setLabelCol("label") - .setPredictionCol("prediction") - .setMetricName("rmse") -val rmse = evaluator.evaluate(predictions) -println("Root Mean Squared Error (RMSE) on test data = " + rmse) - -val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] -println("Learned regression tree model:\n" + treeModel.toDebugString) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala %}
    More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html). -{% highlight java %} -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.ml.regression.DecisionTreeRegressionModel; -import org.apache.spark.ml.regression.DecisionTreeRegressor; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; -import org.apache.spark.sql.DataFrame; - -// Load and parse the data file, converting it to a DataFrame. -RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); -DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); - -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -VectorIndexerModel featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data); - -// Split the data into training and test sets (30% held out for testing) -DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); -DataFrame trainingData = splits[0]; -DataFrame testData = splits[1]; - -// Train a DecisionTree model. -DecisionTreeRegressor dt = new DecisionTreeRegressor() - .setFeaturesCol("indexedFeatures"); - -// Chain indexer and tree in a Pipeline -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {featureIndexer, dt}); - -// Train model. This also runs the indexer. -PipelineModel model = pipeline.fit(trainingData); - -// Make predictions. -DataFrame predictions = model.transform(testData); - -// Select example rows to display. -predictions.select("label", "features").show(5); - -// Select (prediction, true label) and compute test error -RegressionEvaluator evaluator = new RegressionEvaluator() - .setLabelCol("label") - .setPredictionCol("prediction") - .setMetricName("rmse"); -double rmse = evaluator.evaluate(predictions); -System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); - -DecisionTreeRegressionModel treeModel = - (DecisionTreeRegressionModel)(model.stages()[1]); -System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java %}
    More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.DecisionTreeRegressor). -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.regression import DecisionTreeRegressor -from pyspark.ml.feature import VectorIndexer -from pyspark.ml.evaluation import RegressionEvaluator -from pyspark.mllib.util import MLUtils - -# Load and parse the data file, converting it to a DataFrame. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -# Automatically identify categorical features, and index them. -# We specify maxCategories so features with > 4 distinct values are treated as continuous. -featureIndexer =\ - VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) - -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -dt = DecisionTreeRegressor(featuresCol="indexedFeatures") - -# Chain indexer and tree in a Pipeline -pipeline = Pipeline(stages=[featureIndexer, dt]) - -# Train model. This also runs the indexer. -model = pipeline.fit(trainingData) - -# Make predictions. -predictions = model.transform(testData) - -# Select example rows to display. -predictions.select("prediction", "label", "features").show(5) - -# Select (prediction, true label) and compute test error -evaluator = RegressionEvaluator( - labelCol="label", predictionCol="prediction", metricName="rmse") -rmse = evaluator.evaluate(predictions) -print "Root Mean Squared Error (RMSE) on test data = %g" % rmse - -treeModel = model.stages[1] -print treeModel # summary only -{% endhighlight %} +{% include_example python/ml/decision_tree_regression_example.py %}
    diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index 58f566c9b4b55..ce15f5e6466ec 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -195,7 +195,7 @@ import org.apache.spark.ml.feature.*; import org.apache.spark.sql.DataFrame; // Load and parse the data file, converting it to a DataFrame. -DataFrame data = sqlContext.read.format("libsvm") +DataFrame data = sqlContext.read().format("libsvm") .load("data/mllib/sample_libsvm_data.txt"); // Index labels, adding metadata to the label column. @@ -384,7 +384,7 @@ import org.apache.spark.ml.regression.RandomForestRegressor; import org.apache.spark.sql.DataFrame; // Load and parse the data file, converting it to a DataFrame. -DataFrame data = sqlContext.read.format("libsvm") +DataFrame data = sqlContext.read().format("libsvm") .load("data/mllib/sample_libsvm_data.txt"); // Automatically identify categorical features, and index them. @@ -640,7 +640,7 @@ import org.apache.spark.ml.feature.*; import org.apache.spark.sql.DataFrame; // Load and parse the data file, converting it to a DataFrame. -DataFrame data sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt"); +DataFrame data sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); // Index labels, adding metadata to the label column. // Fit on whole dataset to include all labels in index. @@ -830,7 +830,7 @@ import org.apache.spark.ml.regression.GBTRegressor; import org.apache.spark.sql.DataFrame; // Load and parse the data file, converting it to a DataFrame. -DataFrame data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt"); +DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); // Automatically identify categorical features, and index them. // Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -1000,7 +1000,7 @@ SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext jsql = new SQLContext(jsc); -DataFrame dataFrame = sqlContext.read.format("libsvm") +DataFrame dataFrame = sqlContext.read().format("libsvm") .load("data/mllib/sample_multiclass_classification_data.txt"); DataFrame[] splits = dataFrame.randomSplit(new double[] {0.7, 0.3}, 12345); diff --git a/docs/ml-features.md b/docs/ml-features.md index 142afac2f3f95..cd1838d6d2882 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1109,7 +1109,7 @@ import org.apache.spark.ml.feature.VectorIndexer; import org.apache.spark.ml.feature.VectorIndexerModel; import org.apache.spark.sql.DataFrame; -DataFrame data = sqlContext.read.format("libsvm") +DataFrame data = sqlContext.read().format("libsvm") .load("data/mllib/sample_libsvm_data.txt"); VectorIndexer indexer = new VectorIndexer() .setInputCol("features") @@ -1187,7 +1187,7 @@ for more details on the API. import org.apache.spark.ml.feature.Normalizer; import org.apache.spark.sql.DataFrame; -DataFrame dataFrame = sqlContext.read.format("libsvm") +DataFrame dataFrame = sqlContext.read().format("libsvm") .load("data/mllib/sample_libsvm_data.txt"); // Normalize each Vector using $L^1$ norm. @@ -1273,7 +1273,7 @@ import org.apache.spark.ml.feature.StandardScaler; import org.apache.spark.ml.feature.StandardScalerModel; import org.apache.spark.sql.DataFrame; -DataFrame dataFrame = sqlContext.read.format("libsvm") +DataFrame dataFrame = sqlContext.read().format("libsvm") .load("data/mllib/sample_libsvm_data.txt"); StandardScaler scaler = new StandardScaler() .setInputCol("features") @@ -1366,7 +1366,7 @@ import org.apache.spark.ml.feature.MinMaxScaler; import org.apache.spark.ml.feature.MinMaxScalerModel; import org.apache.spark.sql.DataFrame; -DataFrame dataFrame = sqlContext.read.format("libsvm") +DataFrame dataFrame = sqlContext.read().format("libsvm") .load("data/mllib/sample_libsvm_data.txt"); MinMaxScaler scaler = new MinMaxScaler() .setInputCol("features") diff --git a/docs/ml-guide.md b/docs/ml-guide.md index fd3a6167bc65e..be18a05361a17 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -44,6 +44,7 @@ provide class probabilities, and linear models provide model summaries. * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) * [Multilayer perceptron classifier](ml-ann.html) +* [Survival Regression](ml-survival-regression.html) # Main concepts in Pipelines @@ -866,10 +867,9 @@ The `ParamMap` which produces the best evaluation metric is selected as the best import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} -import org.apache.spark.mllib.util.MLUtils // Prepare training and test data. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) val lr = new LinearRegression() @@ -910,14 +910,9 @@ import org.apache.spark.ml.evaluation.RegressionEvaluator; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.ml.regression.LinearRegression; import org.apache.spark.ml.tuning.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; -DataFrame data = sqlContext.createDataFrame( - MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"), - LabeledPoint.class); +DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); // Prepare training and test data. DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md index 16e2ee71293ae..85edfd373465f 100644 --- a/docs/ml-linear-methods.md +++ b/docs/ml-linear-methods.md @@ -95,7 +95,7 @@ public class LogisticRegressionWithElasticNetExample { String path = "data/mllib/sample_libsvm_data.txt"; // Load training data - DataFrame training = sqlContext.read.format("libsvm").load(path); + DataFrame training = sqlContext.read().format("libsvm").load(path); LogisticRegression lr = new LogisticRegression() .setMaxIter(10) @@ -292,7 +292,7 @@ public class LinearRegressionWithElasticNetExample { String path = "data/mllib/sample_libsvm_data.txt"; // Load training data - DataFrame training = sqlContext.read.format("libsvm").load(path); + DataFrame training = sqlContext.read().format("libsvm").load(path); LinearRegression lr = new LinearRegression() .setMaxIter(10) diff --git a/docs/ml-survival-regression.md b/docs/ml-survival-regression.md new file mode 100644 index 0000000000000..ab275213b9a84 --- /dev/null +++ b/docs/ml-survival-regression.md @@ -0,0 +1,96 @@ +--- +layout: global +title: Survival Regression - ML +displayTitle: ML - Survival Regression +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + +In `spark.ml`, we implement the [Accelerated failure time (AFT)](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) +model which is a parametric survival regression model for censored data. +It describes a model for the log of survival time, so it's often called +log-linear model for survival analysis. Different from +[Proportional hazards](https://en.wikipedia.org/wiki/Proportional_hazards_model) model +designed for the same purpose, the AFT model is more easily to parallelize +because each instance contribute to the objective function independently. + +Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of +subjects i = 1, ..., n, with possible right-censoring, +the likelihood function under the AFT model is given as: +`\[ +L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} +\]` +Where $\delta_{i}$ is the indicator of the event has occurred i.e. uncensored or not. +Using $\epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}$, the log-likelihood function +assumes the form: +`\[ +\iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+\delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] +\]` +Where $S_{0}(\epsilon_{i})$ is the baseline survivor function, +and $f_{0}(\epsilon_{i})$ is corresponding density function. + +The most commonly used AFT model is based on the Weibull distribution of the survival time. +The Weibull distribution for lifetime corresponding to extreme value distribution for +log of the lifetime, and the $S_{0}(\epsilon)$ function is: +`\[ +S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) +\]` +the $f_{0}(\epsilon_{i})$ function is: +`\[ +f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) +\]` +The log-likelihood function for AFT model with Weibull distribution of lifetime is: +`\[ +\iota(\beta,\sigma)= -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] +\]` +Due to minimizing the negative log-likelihood equivalent to maximum a posteriori probability, +the loss function we use to optimize is $-\iota(\beta,\sigma)$. +The gradient functions for $\beta$ and $\log\sigma$ respectively are: +`\[ +\frac{\partial (-\iota)}{\partial \beta}=\sum_{1=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} +\]` +`\[ +\frac{\partial (-\iota)}{\partial (\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] +\]` + +The AFT model can be formulated as a convex optimization problem, +i.e. the task of finding a minimizer of a convex function $-\iota(\beta,\sigma)$ +that depends coefficients vector $\beta$ and the log of scale parameter $\log\sigma$. +The optimization algorithm underlying the implementation is L-BFGS. +The implementation matches the result from R's survival function +[survreg](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) + +## Example: + +
    + +
    +{% include_example scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala %} +
    + +
    +{% include_example java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java %} +
    + +
    +{% include_example python/ml/aft_survival_regression.py %} +
    + +
    \ No newline at end of file diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 1ad52123c74aa..7cd1b894e7cb5 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -66,43 +66,7 @@ recommendation model by measuring the Mean Squared Error of rating prediction. Refer to the [`ALS` Scala docs](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.recommendation.ALS -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel -import org.apache.spark.mllib.recommendation.Rating - -// Load and parse the data -val data = sc.textFile("data/mllib/als/test.data") -val ratings = data.map(_.split(',') match { case Array(user, item, rate) => - Rating(user.toInt, item.toInt, rate.toDouble) - }) - -// Build the recommendation model using ALS -val rank = 10 -val numIterations = 10 -val model = ALS.train(ratings, rank, numIterations, 0.01) - -// Evaluate the model on rating data -val usersProducts = ratings.map { case Rating(user, product, rate) => - (user, product) -} -val predictions = - model.predict(usersProducts).map { case Rating(user, product, rate) => - ((user, product), rate) - } -val ratesAndPreds = ratings.map { case Rating(user, product, rate) => - ((user, product), rate) -}.join(predictions) -val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => - val err = (r1 - r2) - err * err -}.mean() -println("Mean Squared Error = " + MSE) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = MatrixFactorizationModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RecommendationExample.scala %} If the rating matrix is derived from another source of information (e.g., it is inferred from other signals), you can use the `trainImplicit` method to get better results. @@ -123,81 +87,7 @@ that is equivalent to the provided example in Scala is given below: Refer to the [`ALS` Java docs](api/java/org/apache/spark/mllib/recommendation/ALS.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.recommendation.ALS; -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; -import org.apache.spark.mllib.recommendation.Rating; -import org.apache.spark.SparkConf; - -public class CollaborativeFiltering { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Collaborative Filtering Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/als/test.data"; - JavaRDD data = sc.textFile(path); - JavaRDD ratings = data.map( - new Function() { - public Rating call(String s) { - String[] sarray = s.split(","); - return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), - Double.parseDouble(sarray[2])); - } - } - ); - - // Build the recommendation model using ALS - int rank = 10; - int numIterations = 10; - MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); - - // Evaluate the model on rating data - JavaRDD> userProducts = ratings.map( - new Function>() { - public Tuple2 call(Rating r) { - return new Tuple2(r.user(), r.product()); - } - } - ); - JavaPairRDD, Double> predictions = JavaPairRDD.fromJavaRDD( - model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( - new Function, Double>>() { - public Tuple2, Double> call(Rating r){ - return new Tuple2, Double>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )); - JavaRDD> ratesAndPreds = - JavaPairRDD.fromJavaRDD(ratings.map( - new Function, Double>>() { - public Tuple2, Double> call(Rating r){ - return new Tuple2, Double>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )).join(predictions).values(); - double MSE = JavaDoubleRDD.fromRDD(ratesAndPreds.map( - new Function, Object>() { - public Object call(Tuple2 pair) { - Double err = pair._1() - pair._2(); - return err * err; - } - } - ).rdd()).mean(); - System.out.println("Mean Squared Error = " + MSE); - - // Save and load model - model.save(sc.sc(), "myModelPath"); - MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(sc.sc(), "myModelPath"); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRecommendationExample.java %}
    @@ -207,29 +97,7 @@ recommendation by measuring the Mean Squared Error of rating prediction. Refer to the [`ALS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.recommendation.ALS) for more details on the API. -{% highlight python %} -from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating - -# Load and parse the data -data = sc.textFile("data/mllib/als/test.data") -ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2]))) - -# Build the recommendation model using Alternating Least Squares -rank = 10 -numIterations = 10 -model = ALS.train(ratings, rank, numIterations) - -# Evaluate the model on training data -testdata = ratings.map(lambda p: (p[0], p[1])) -predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) -ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) -MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).mean() -print("Mean Squared Error = " + str(MSE)) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = MatrixFactorizationModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/recommendation_example.py %} If the rating matrix is derived from other source of information (i.e., it is inferred from other signals), you can use the trainImplicit method to get better results. diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index f31c4f88936bd..77ce34e91af3c 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -194,137 +194,19 @@ maximum tree depth of 5. The test error is calculated to measure the algorithm a
    Refer to the [`DecisionTree` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.DecisionTreeModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a DecisionTree model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -val numClasses = 2 -val categoricalFeaturesInfo = Map[Int, Int]() -val impurity = "gini" -val maxDepth = 5 -val maxBins = 32 - -val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, - impurity, maxDepth, maxBins) - -// Evaluate model on test instances and compute test error -val labelAndPreds = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() -println("Test Error = " + testErr) -println("Learned classification tree model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala %}
    Refer to the [`DecisionTree` Java docs](api/java/org/apache/spark/mllib/tree/DecisionTree.html) and [`DecisionTreeModel` Java docs](api/java/org/apache/spark/mllib/tree/model/DecisionTreeModel.html) for details on the API. -{% highlight java %} -import java.util.HashMap; -import scala.Tuple2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.DecisionTree; -import org.apache.spark.mllib.tree.model.DecisionTreeModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Set parameters. -// Empty categoricalFeaturesInfo indicates all features are continuous. -Integer numClasses = 2; -Map categoricalFeaturesInfo = new HashMap(); -String impurity = "gini"; -Integer maxDepth = 5; -Integer maxBins = 32; - -// Train a DecisionTree model for classification. -final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); -System.out.println("Test Error: " + testErr); -System.out.println("Learned classification tree model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java %}
    Refer to the [`DecisionTree` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTreeModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree, DecisionTreeModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file into an RDD of LabeledPoint. -data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -# Empty categoricalFeaturesInfo indicates all features are continuous. -model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, - impurity='gini', maxDepth=5, maxBins=32) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) -print('Test Error = ' + str(testErr)) -print('Learned classification tree model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/decision_tree_classification_example.py %}
    @@ -343,142 +225,19 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
    Refer to the [`DecisionTree` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.DecisionTreeModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a DecisionTree model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -val categoricalFeaturesInfo = Map[Int, Int]() -val impurity = "variance" -val maxDepth = 5 -val maxBins = 32 - -val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity, - maxDepth, maxBins) - -// Evaluate model on test instances and compute test error -val labelsAndPredictions = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() -println("Test Mean Squared Error = " + testMSE) -println("Learned regression tree model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala %}
    Refer to the [`DecisionTree` Java docs](api/java/org/apache/spark/mllib/tree/DecisionTree.html) and [`DecisionTreeModel` Java docs](api/java/org/apache/spark/mllib/tree/model/DecisionTreeModel.html) for details on the API. -{% highlight java %} -import java.util.HashMap; -import scala.Tuple2; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.DecisionTree; -import org.apache.spark.mllib.tree.model.DecisionTreeModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Set parameters. -// Empty categoricalFeaturesInfo indicates all features are continuous. -Map categoricalFeaturesInfo = new HashMap(); -String impurity = "variance"; -Integer maxDepth = 5; -Integer maxBins = 32; - -// Train a DecisionTree model. -final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); -System.out.println("Test Mean Squared Error: " + testMSE); -System.out.println("Learned regression tree model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java %}
    Refer to the [`DecisionTree` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTreeModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree, DecisionTreeModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file into an RDD of LabeledPoint. -data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -# Empty categoricalFeaturesInfo indicates all features are continuous. -model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={}, - impurity='variance', maxDepth=5, maxBins=32) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count()) -print('Test Mean Squared Error = ' + str(testMSE)) -print('Learned regression tree model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/decision_tree_regression_example.py %}
    diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index fc587298f7d2e..50450e05d2abb 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -98,144 +98,19 @@ The test error is calculated to measure the algorithm accuracy.
    Refer to the [`RandomForest` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest) and [`RandomForestModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.RandomForestModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.tree.RandomForest -import org.apache.spark.mllib.tree.model.RandomForestModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a RandomForest model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -val numClasses = 2 -val categoricalFeaturesInfo = Map[Int, Int]() -val numTrees = 3 // Use more in practice. -val featureSubsetStrategy = "auto" // Let the algorithm choose. -val impurity = "gini" -val maxDepth = 4 -val maxBins = 32 - -val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, - numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins) - -// Evaluate model on test instances and compute test error -val labelAndPreds = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() -println("Test Error = " + testErr) -println("Learned classification forest model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = RandomForestModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala %}
    Refer to the [`RandomForest` Java docs](api/java/org/apache/spark/mllib/tree/RandomForest.html) and [`RandomForestModel` Java docs](api/java/org/apache/spark/mllib/tree/model/RandomForestModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; -import java.util.HashMap; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.RandomForest; -import org.apache.spark.mllib.tree.model.RandomForestModel; -import org.apache.spark.mllib.util.MLUtils; - -SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestClassification"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Train a RandomForest model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -Integer numClasses = 2; -HashMap categoricalFeaturesInfo = new HashMap(); -Integer numTrees = 3; // Use more in practice. -String featureSubsetStrategy = "auto"; // Let the algorithm choose. -String impurity = "gini"; -Integer maxDepth = 5; -Integer maxBins = 32; -Integer seed = 12345; - -final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, - categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, - seed); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); -System.out.println("Test Error: " + testErr); -System.out.println("Learned classification forest model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java %}
    Refer to the [`RandomForest` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.RandomForest) and [`RandomForest` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.RandomForestModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.tree import RandomForest, RandomForestModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file into an RDD of LabeledPoint. -data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a RandomForest model. -# Empty categoricalFeaturesInfo indicates all features are continuous. -# Note: Use larger numTrees in practice. -# Setting featureSubsetStrategy="auto" lets the algorithm choose. -model = RandomForest.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, - numTrees=3, featureSubsetStrategy="auto", - impurity='gini', maxDepth=4, maxBins=32) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) -print('Test Error = ' + str(testErr)) -print('Learned classification forest model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = RandomForestModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/random_forest_classification_example.py %}
    @@ -254,147 +129,19 @@ The Mean Squared Error (MSE) is computed at the end to evaluate
    Refer to the [`RandomForest` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest) and [`RandomForestModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.RandomForestModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.tree.RandomForest -import org.apache.spark.mllib.tree.model.RandomForestModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a RandomForest model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -val numClasses = 2 -val categoricalFeaturesInfo = Map[Int, Int]() -val numTrees = 3 // Use more in practice. -val featureSubsetStrategy = "auto" // Let the algorithm choose. -val impurity = "variance" -val maxDepth = 4 -val maxBins = 32 - -val model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo, - numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins) - -// Evaluate model on test instances and compute test error -val labelsAndPredictions = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() -println("Test Mean Squared Error = " + testMSE) -println("Learned regression forest model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = RandomForestModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala %}
    Refer to the [`RandomForest` Java docs](api/java/org/apache/spark/mllib/tree/RandomForest.html) and [`RandomForestModel` Java docs](api/java/org/apache/spark/mllib/tree/model/RandomForestModel.html) for details on the API. -{% highlight java %} -import java.util.HashMap; -import scala.Tuple2; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.RandomForest; -import org.apache.spark.mllib.tree.model.RandomForestModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForest"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Set parameters. -// Empty categoricalFeaturesInfo indicates all features are continuous. -Map categoricalFeaturesInfo = new HashMap(); -String impurity = "variance"; -Integer maxDepth = 4; -Integer maxBins = 32; - -// Train a RandomForest model. -final RandomForestModel model = RandomForest.trainRegressor(trainingData, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / testData.count(); -System.out.println("Test Mean Squared Error: " + testMSE); -System.out.println("Learned regression forest model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java %}
    Refer to the [`RandomForest` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.RandomForest) and [`RandomForest` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.RandomForestModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.tree import RandomForest, RandomForestModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file into an RDD of LabeledPoint. -data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a RandomForest model. -# Empty categoricalFeaturesInfo indicates all features are continuous. -# Note: Use larger numTrees in practice. -# Setting featureSubsetStrategy="auto" lets the algorithm choose. -model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo={}, - numTrees=3, featureSubsetStrategy="auto", - impurity='variance', maxDepth=4, maxBins=32) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count()) -print('Test Mean Squared Error = ' + str(testMSE)) -print('Learned regression forest model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = RandomForestModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/random_forest_regression_example.py %}
    @@ -492,141 +239,19 @@ The test error is calculated to measure the algorithm accuracy.
    Refer to the [`GradientBoostedTrees` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`GradientBoostedTreesModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.GradientBoostedTreesModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.tree.GradientBoostedTrees -import org.apache.spark.mllib.tree.configuration.BoostingStrategy -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a GradientBoostedTrees model. -// The defaultParams for Classification use LogLoss by default. -val boostingStrategy = BoostingStrategy.defaultParams("Classification") -boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. -boostingStrategy.treeStrategy.numClasses = 2 -boostingStrategy.treeStrategy.maxDepth = 5 -// Empty categoricalFeaturesInfo indicates all features are continuous. -boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() - -val model = GradientBoostedTrees.train(trainingData, boostingStrategy) - -// Evaluate model on test instances and compute test error -val labelAndPreds = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() -println("Test Error = " + testErr) -println("Learned classification GBT model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala %}
    Refer to the [`GradientBoostedTrees` Java docs](api/java/org/apache/spark/mllib/tree/GradientBoostedTrees.html) and [`GradientBoostedTreesModel` Java docs](api/java/org/apache/spark/mllib/tree/model/GradientBoostedTreesModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; -import java.util.HashMap; -import java.util.Map; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.GradientBoostedTrees; -import org.apache.spark.mllib.tree.configuration.BoostingStrategy; -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; -import org.apache.spark.mllib.util.MLUtils; - -SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Train a GradientBoostedTrees model. -// The defaultParams for Classification use LogLoss by default. -BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification"); -boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. -boostingStrategy.getTreeStrategy().setNumClassesForClassification(2); -boostingStrategy.getTreeStrategy().setMaxDepth(5); -// Empty categoricalFeaturesInfo indicates all features are continuous. -Map categoricalFeaturesInfo = new HashMap(); -boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); - -final GradientBoostedTreesModel model = - GradientBoostedTrees.train(trainingData, boostingStrategy); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); -System.out.println("Test Error: " + testErr); -System.out.println("Learned classification GBT model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java %}
    Refer to the [`GradientBoostedTrees` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.GradientBoostedTrees) and [`GradientBoostedTreesModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.GradientBoostedTreesModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a GradientBoostedTrees model. -# Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. -# (b) Use more iterations in practice. -model = GradientBoostedTrees.trainClassifier(trainingData, - categoricalFeaturesInfo={}, numIterations=3) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) -print('Test Error = ' + str(testErr)) -print('Learned classification GBT model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/gradient_boosting_classification_example.py %}
    @@ -645,146 +270,19 @@ The Mean Squared Error (MSE) is computed at the end to evaluate
    Refer to the [`GradientBoostedTrees` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`GradientBoostedTreesModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.GradientBoostedTreesModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.tree.GradientBoostedTrees -import org.apache.spark.mllib.tree.configuration.BoostingStrategy -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a GradientBoostedTrees model. -// The defaultParams for Regression use SquaredError by default. -val boostingStrategy = BoostingStrategy.defaultParams("Regression") -boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. -boostingStrategy.treeStrategy.maxDepth = 5 -// Empty categoricalFeaturesInfo indicates all features are continuous. -boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() - -val model = GradientBoostedTrees.train(trainingData, boostingStrategy) - -// Evaluate model on test instances and compute test error -val labelsAndPredictions = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() -println("Test Mean Squared Error = " + testMSE) -println("Learned regression GBT model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala %}
    Refer to the [`GradientBoostedTrees` Java docs](api/java/org/apache/spark/mllib/tree/GradientBoostedTrees.html) and [`GradientBoostedTreesModel` Java docs](api/java/org/apache/spark/mllib/tree/model/GradientBoostedTreesModel.html) for details on the API. -{% highlight java %} -import scala.Tuple2; -import java.util.HashMap; -import java.util.Map; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.GradientBoostedTrees; -import org.apache.spark.mllib.tree.configuration.BoostingStrategy; -import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; -import org.apache.spark.mllib.util.MLUtils; - -SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Train a GradientBoostedTrees model. -// The defaultParams for Regression use SquaredError by default. -BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression"); -boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. -boostingStrategy.getTreeStrategy().setMaxDepth(5); -// Empty categoricalFeaturesInfo indicates all features are continuous. -Map categoricalFeaturesInfo = new HashMap(); -boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); - -final GradientBoostedTreesModel model = - GradientBoostedTrees.train(trainingData, boostingStrategy); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); -System.out.println("Test Mean Squared Error: " + testMSE); -System.out.println("Learned regression GBT model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java %}
    Refer to the [`GradientBoostedTrees` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.GradientBoostedTrees) and [`GradientBoostedTreesModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.GradientBoostedTreesModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a GradientBoostedTrees model. -# Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. -# (b) Use more iterations in practice. -model = GradientBoostedTrees.trainRegressor(trainingData, - categoricalFeaturesInfo={}, numIterations=3) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count()) -print('Test Mean Squared Error = ' + str(testMSE)) -print('Learned regression GBT model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = GradientBoostedTreesModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/gradient_boosting_regression_example.py %}
    diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index a3bd130ba077c..ad7bcd9bfd407 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -220,154 +220,13 @@ L-BFGS optimizer.
    Refer to the [`LBFGS` Scala docs](api/scala/index.html#org.apache.spark.mllib.optimization.LBFGS) and [`SquaredL2Updater` Scala docs](api/scala/index.html#org.apache.spark.mllib.optimization.SquaredL2Updater) for details on the API. -{% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.classification.LogisticRegressionModel -import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater} - -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -val numFeatures = data.take(1)(0).features.size - -// Split data into training (60%) and test (40%). -val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) - -// Append 1 into the training data as intercept. -val training = splits(0).map(x => (x.label, MLUtils.appendBias(x.features))).cache() - -val test = splits(1) - -// Run training algorithm to build the model -val numCorrections = 10 -val convergenceTol = 1e-4 -val maxNumIterations = 20 -val regParam = 0.1 -val initialWeightsWithIntercept = Vectors.dense(new Array[Double](numFeatures + 1)) - -val (weightsWithIntercept, loss) = LBFGS.runLBFGS( - training, - new LogisticGradient(), - new SquaredL2Updater(), - numCorrections, - convergenceTol, - maxNumIterations, - regParam, - initialWeightsWithIntercept) - -val model = new LogisticRegressionModel( - Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)), - weightsWithIntercept(weightsWithIntercept.size - 1)) - -// Clear the default threshold. -model.clearThreshold() - -// Compute raw scores on the test set. -val scoreAndLabels = test.map { point => - val score = model.predict(point.features) - (score, point.label) -} - -// Get evaluation metrics. -val metrics = new BinaryClassificationMetrics(scoreAndLabels) -val auROC = metrics.areaUnderROC() - -println("Loss of each step in training process") -loss.foreach(println) -println("Area under ROC = " + auROC) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/LBFGSExample.scala %}
    Refer to the [`LBFGS` Java docs](api/java/org/apache/spark/mllib/optimization/LBFGS.html) and [`SquaredL2Updater` Java docs](api/java/org/apache/spark/mllib/optimization/SquaredL2Updater.html) for details on the API. -{% highlight java %} -import java.util.Arrays; -import java.util.Random; - -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.optimization.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class LBFGSExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("L-BFGS Example"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_libsvm_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - int numFeatures = data.take(1).get(0).features().size(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD trainingInit = data.sample(false, 0.6, 11L); - JavaRDD test = data.subtract(trainingInit); - - // Append 1 into the training data as intercept. - JavaRDD> training = data.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - return new Tuple2(p.label(), MLUtils.appendBias(p.features())); - } - }); - training.cache(); - - // Run training algorithm to build the model. - int numCorrections = 10; - double convergenceTol = 1e-4; - int maxNumIterations = 20; - double regParam = 0.1; - Vector initialWeightsWithIntercept = Vectors.dense(new double[numFeatures + 1]); - - Tuple2 result = LBFGS.runLBFGS( - training.rdd(), - new LogisticGradient(), - new SquaredL2Updater(), - numCorrections, - convergenceTol, - maxNumIterations, - regParam, - initialWeightsWithIntercept); - Vector weightsWithIntercept = result._1(); - double[] loss = result._2(); - - final LogisticRegressionModel model = new LogisticRegressionModel( - Vectors.dense(Arrays.copyOf(weightsWithIntercept.toArray(), weightsWithIntercept.size() - 1)), - (weightsWithIntercept.toArray())[weightsWithIntercept.size() - 1]); - - // Clear the default threshold. - model.clearThreshold(); - - // Compute raw scores on the test set. - JavaRDD> scoreAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double score = model.predict(p.features()); - return new Tuple2(score, p.label()); - } - }); - - // Get evaluation metrics. - BinaryClassificationMetrics metrics = - new BinaryClassificationMetrics(scoreAndLabels.rdd()); - double auROC = metrics.areaUnderROC(); - - System.out.println("Loss of each step in training process"); - for (double l : loss) - System.out.println(l); - System.out.println("Area under ROC = " + auROC); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaLBFGSExample.java %}
    diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index 2c7c9ed693fd4..ade5b0768aefe 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -594,7 +594,7 @@ sc = ... # SparkContext # Generate a random double RDD that contains 1 million i.i.d. values drawn from the # standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. -u = RandomRDDs.uniformRDD(sc, 1000000L, 10) +u = RandomRDDs.normalRDD(sc, 1000000L, 10) # Apply a transform to get a random double RDD following `N(1, 4)`. v = u.map(lambda x: 1.0 + 2.0 * x) {% endhighlight %} diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2fe5c36338899..6e02d6564b002 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -982,7 +982,8 @@ when a table is dropped. [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. Spark SQL provides support for both reading and writing Parquet files that automatically preserves the schema -of the original data. +of the original data. When writing Parquet files, all columns are automatically converted to be nullable for +compatibility reasons. ### Loading Data Programmatically @@ -1089,15 +1090,6 @@ for (teenName in collect(teenNames)) { -
    - -{% highlight python %} -# sqlContext is an existing HiveContext -sqlContext.sql("REFRESH TABLE my_table") -{% endhighlight %} - -
    -
    {% highlight sql %} @@ -1636,8 +1628,10 @@ YARN cluster. The convenient way to do this is adding them through the `--jars` When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do not have an existing Hive deployment can still create a `HiveContext`. When not configured by the -hive-site.xml, the context automatically creates `metastore_db` and `warehouse` in the current -directory. +hive-site.xml, the context automatically creates `metastore_db` in the current directory and +creates `warehouse` directory indicated by HiveConf, which defaults to `/user/hive/warehouse`. +Note that you may need to grant write privilege on `/user/hive/warehouse` to the user who starts +the spark application. {% highlight scala %} // sc is an existing SparkContext. @@ -2294,7 +2288,7 @@ Several caching related features are not supported yet: Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 1.2.1. Also see http://spark.apache.org/docs/latest/sql-programming-guide.html#interacting-with-different-versions-of-hive-metastore). +(from 0.12.0 to 1.2.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index ab7f0117c0b7f..b00351b2fbcc0 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -181,7 +181,20 @@ Next, we discuss how to use this approach in your streaming application. );
    - Not supported yet + offsetRanges = [] + + def storeOffsetRanges(rdd): + global offsetRanges + offsetRanges = rdd.offsetRanges() + return rdd + + def printOffsetRanges(rdd): + for o in offsetRanges: + print "%s %s %s %s" % (o.topic, o.partition, o.fromOffset, o.untilOffset) + + directKafkaStream\ + .transform(storeOffsetRanges)\ + .foreachRDD(printOffsetRanges)
    diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index c751dbb41785a..e9a27f446a898 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1948,8 +1948,8 @@ unifiedStream.print(); {% highlight python %} numStreams = 5 kafkaStreams = [KafkaUtils.createStream(...) for _ in range (numStreams)] -unifiedStream = streamingContext.union(kafkaStreams) -unifiedStream.print() +unifiedStream = streamingContext.union(*kafkaStreams) +unifiedStream.pprint() {% endhighlight %} diff --git a/docs/tuning.md b/docs/tuning.md index 6936912a6be54..879340a01544f 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -61,8 +61,8 @@ The [Kryo documentation](https://github.com/EsotericSoftware/kryo) describes mor registration options, such as adding custom serialization code. If your objects are large, you may also need to increase the `spark.kryoserializer.buffer` -config property. The default is 2, but this value needs to be large enough to hold the *largest* -object you will serialize. +[config](configuration.html#compression-and-serialization). This value needs to be large enough +to hold the *largest* object you will serialize. Finally, if you don't register your custom classes, Kryo will still work, but it will have to store the full class name with each object, which is wasteful. diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java new file mode 100644 index 0000000000000..69a174562fcf5 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.regression.AFTSurvivalRegression; +import org.apache.spark.ml.regression.AFTSurvivalRegressionModel; +import org.apache.spark.mllib.linalg.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaAFTSurvivalRegressionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaAFTSurvivalRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(1.218, 1.0, Vectors.dense(1.560, -0.605)), + RowFactory.create(2.949, 0.0, Vectors.dense(0.346, 2.158)), + RowFactory.create(3.627, 0.0, Vectors.dense(1.380, 0.231)), + RowFactory.create(0.273, 1.0, Vectors.dense(0.520, 1.151)), + RowFactory.create(4.199, 0.0, Vectors.dense(0.795, -0.226)) + ); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("censor", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()) + }); + DataFrame training = jsql.createDataFrame(data, schema); + double[] quantileProbabilities = new double[]{0.3, 0.6}; + AFTSurvivalRegression aft = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles"); + + AFTSurvivalRegressionModel model = aft.fit(training); + + // Print the coefficients, intercept and scale parameter for AFT survival regression + System.out.println("Coefficients: " + model.coefficients() + " Intercept: " + + model.intercept() + " Scale: " + model.scale()); + model.transform(training).show(false); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java new file mode 100644 index 0000000000000..482225e585cf8 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// scalastyle:off println +package org.apache.spark.examples.ml; +// $example on$ +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.DecisionTreeClassifier; +import org.apache.spark.ml.classification.DecisionTreeClassificationModel; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaDecisionTreeClassificationExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); + + // Automatically identify categorical features, and index them. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3}); + DataFrame trainingData = splits[0]; + DataFrame testData = splits[1]; + + // Train a DecisionTree model. + DecisionTreeClassifier dt = new DecisionTreeClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures"); + + // Convert indexed labels back to original labels. + IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + + // Chain indexers and tree in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter}); + + // Train model. This also runs the indexers. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + DataFrame predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision"); + double accuracy = evaluator.evaluate(predictions); + System.out.println("Test Error = " + (1.0 - accuracy)); + + DecisionTreeClassificationModel treeModel = + (DecisionTreeClassificationModel) (model.stages()[2]); + System.out.println("Learned classification tree model:\n" + treeModel.toDebugString()); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java new file mode 100644 index 0000000000000..c7f1868dd105a --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// scalastyle:off println +package org.apache.spark.examples.ml; +// $example on$ +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.DecisionTreeRegressionModel; +import org.apache.spark.ml.regression.DecisionTreeRegressor; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaDecisionTreeRegressionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + DataFrame data = sqlContext.read().format("libsvm") + .load("data/mllib/sample_libsvm_data.txt"); + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3}); + DataFrame trainingData = splits[0]; + DataFrame testData = splits[1]; + + // Train a DecisionTree model. + DecisionTreeRegressor dt = new DecisionTreeRegressor() + .setFeaturesCol("indexedFeatures"); + + // Chain indexer and tree in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[]{featureIndexer, dt}); + + // Train model. This also runs the indexer. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + DataFrame predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("label", "features").show(5); + + // Select (prediction, true label) and compute test error + RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse"); + double rmse = evaluator.evaluate(predictions); + System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + + DecisionTreeRegressionModel treeModel = + (DecisionTreeRegressionModel) (model.stages()[1]); + System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java new file mode 100644 index 0000000000000..84369f6681d04 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel; +import org.apache.spark.ml.classification.MultilayerPerceptronClassifier; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.sql.DataFrame; +// $example off$ + +/** + * An example for Multilayer Perceptron Classification. + */ +public class JavaMultilayerPerceptronClassifierExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaMultilayerPerceptronClassifierExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + // Load training data + String path = "data/mllib/sample_multiclass_classification_data.txt"; + DataFrame dataFrame = jsql.read().format("libsvm").load(path); + // Split the data into train and test + DataFrame[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L); + DataFrame train = splits[0]; + DataFrame test = splits[1]; + // specify layers for the neural network: + // input layer of size 4 (features), two intermediate of size 5 and 4 + // and output of size 3 (classes) + int[] layers = new int[] {4, 5, 4, 3}; + // create the trainer and set its parameters + MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(128) + .setSeed(1234L) + .setMaxIter(100); + // train the model + MultilayerPerceptronClassificationModel model = trainer.fit(train); + // compute precision on the test set + DataFrame result = model.transform(test); + DataFrame predictionAndLabels = result.select("prediction", "label"); + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision"); + System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels)); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java index e7f2f6f615070..f0d92a56bee73 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java @@ -27,9 +27,7 @@ import org.apache.spark.ml.util.MetadataUtils; import org.apache.spark.mllib.evaluation.MulticlassMetrics; import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.StructField; @@ -80,31 +78,30 @@ public static void main(String[] args) { OneVsRest ovr = new OneVsRest().setClassifier(classifier); String input = params.input; - RDD inputData = MLUtils.loadLibSVMFile(jsc.sc(), input); - RDD train; - RDD test; + DataFrame inputData = jsql.read().format("libsvm").load(input); + DataFrame train; + DataFrame test; // compute the train/ test split: if testInput is not provided use part of input String testInput = params.testInput; if (testInput != null) { train = inputData; // compute the number of features in the training set. - int numFeatures = inputData.first().features().size(); - test = MLUtils.loadLibSVMFile(jsc.sc(), testInput, numFeatures); + int numFeatures = inputData.first().getAs(1).size(); + test = jsql.read().format("libsvm").option("numFeatures", + String.valueOf(numFeatures)).load(testInput); } else { double f = params.fracTest; - RDD[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345); + DataFrame[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345); train = tmp[0]; test = tmp[1]; } // train the multiclass model - DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class); - OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache()); + OneVsRestModel ovrModel = ovr.fit(train.cache()); // score the model on test data - DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class); - DataFrame predictions = ovrModel.transform(testDataFrame.cache()) + DataFrame predictions = ovrModel.transform(test.cache()) .select("prediction", "label"); // obtain metrics diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java index 23f834ab4332b..d433905fc8012 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java @@ -23,8 +23,6 @@ import org.apache.spark.ml.param.ParamMap; import org.apache.spark.ml.regression.LinearRegression; import org.apache.spark.ml.tuning.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; @@ -46,9 +44,7 @@ public static void main(String[] args) { JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext jsql = new SQLContext(jsc); - DataFrame data = jsql.createDataFrame( - MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"), - LabeledPoint.class); + DataFrame data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); // Prepare training and test data. DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java new file mode 100644 index 0000000000000..5839b0cf8a8f8 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.DecisionTree; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +class JavaDecisionTreeClassificationExample { + + public static void main(String[] args) { + + // $example on$ + SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Set parameters. + // Empty categoricalFeaturesInfo indicates all features are continuous. + Integer numClasses = 2; + Map categoricalFeaturesInfo = new HashMap(); + String impurity = "gini"; + Integer maxDepth = 5; + Integer maxBins = 32; + + // Train a DecisionTree model for classification. + final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double testErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / testData.count(); + + System.out.println("Test Error: " + testErr); + System.out.println("Learned classification tree model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel"); + DecisionTreeModel sameModel = DecisionTreeModel + .load(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java new file mode 100644 index 0000000000000..ccde578249f7c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.DecisionTree; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +class JavaDecisionTreeRegressionExample { + + public static void main(String[] args) { + + // $example on$ + SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Set parameters. + // Empty categoricalFeaturesInfo indicates all features are continuous. + Map categoricalFeaturesInfo = new HashMap(); + String impurity = "variance"; + Integer maxDepth = 5; + Integer maxBins = 32; + + // Train a DecisionTree model. + final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double testMSE = + predictionAndLabel.map(new Function, Double>() { + @Override + public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override + public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Test Mean Squared Error: " + testMSE); + System.out.println("Learned regression tree model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel"); + DecisionTreeModel sameModel = DecisionTreeModel + .load(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java new file mode 100644 index 0000000000000..80faabd2325d0 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingClassificationExample.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.GradientBoostedTrees; +import org.apache.spark.mllib.tree.configuration.BoostingStrategy; +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +public class JavaGradientBoostingClassificationExample { + public static void main(String[] args) { + // $example on$ + SparkConf sparkConf = new SparkConf() + .setAppName("JavaGradientBoostedTreesClassificationExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Train a GradientBoostedTrees model. + // The defaultParams for Classification use LogLoss by default. + BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification"); + boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. + boostingStrategy.getTreeStrategy().setNumClasses(2); + boostingStrategy.getTreeStrategy().setMaxDepth(5); + // Empty categoricalFeaturesInfo indicates all features are continuous. + Map categoricalFeaturesInfo = new HashMap(); + boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); + + final GradientBoostedTreesModel model = + GradientBoostedTrees.train(trainingData, boostingStrategy); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double testErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / testData.count(); + System.out.println("Test Error: " + testErr); + System.out.println("Learned classification GBT model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myGradientBoostingClassificationModel"); + GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(), + "target/tmp/myGradientBoostingClassificationModel"); + // $example off$ + } + +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java new file mode 100644 index 0000000000000..216895b368202 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostingRegressionExample.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.GradientBoostedTrees; +import org.apache.spark.mllib.tree.configuration.BoostingStrategy; +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +public class JavaGradientBoostingRegressionExample { + public static void main(String[] args) { + // $example on$ + SparkConf sparkConf = new SparkConf() + .setAppName("JavaGradientBoostedTreesRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Train a GradientBoostedTrees model. + // The defaultParams for Regression use SquaredError by default. + BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression"); + boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice. + boostingStrategy.getTreeStrategy().setMaxDepth(5); + // Empty categoricalFeaturesInfo indicates all features are continuous. + Map categoricalFeaturesInfo = new HashMap(); + boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo); + + final GradientBoostedTreesModel model = + GradientBoostedTrees.train(trainingData, boostingStrategy); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double testMSE = + predictionAndLabel.map(new Function, Double>() { + @Override + public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override + public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Test Mean Squared Error: " + testMSE); + System.out.println("Learned regression GBT model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myGradientBoostingRegressionModel"); + GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(), + "target/tmp/myGradientBoostingRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java new file mode 100644 index 0000000000000..355883f61bd64 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.optimization.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example off$ + +public class JavaLBFGSExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("L-BFGS Example"); + SparkContext sc = new SparkContext(conf); + + // $example on$ + String path = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + int numFeatures = data.take(1).get(0).features().size(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD trainingInit = data.sample(false, 0.6, 11L); + JavaRDD test = data.subtract(trainingInit); + + // Append 1 into the training data as intercept. + JavaRDD> training = data.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + return new Tuple2(p.label(), MLUtils.appendBias(p.features())); + } + }); + training.cache(); + + // Run training algorithm to build the model. + int numCorrections = 10; + double convergenceTol = 1e-4; + int maxNumIterations = 20; + double regParam = 0.1; + Vector initialWeightsWithIntercept = Vectors.dense(new double[numFeatures + 1]); + + Tuple2 result = LBFGS.runLBFGS( + training.rdd(), + new LogisticGradient(), + new SquaredL2Updater(), + numCorrections, + convergenceTol, + maxNumIterations, + regParam, + initialWeightsWithIntercept); + Vector weightsWithIntercept = result._1(); + double[] loss = result._2(); + + final LogisticRegressionModel model = new LogisticRegressionModel( + Vectors.dense(Arrays.copyOf(weightsWithIntercept.toArray(), weightsWithIntercept.size() - 1)), + (weightsWithIntercept.toArray())[weightsWithIntercept.size() - 1]); + + // Clear the default threshold. + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> scoreAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double score = model.predict(p.features()); + return new Tuple2(score, p.label()); + } + }); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = + new BinaryClassificationMetrics(scoreAndLabels.rdd()); + double auROC = metrics.areaUnderROC(); + + System.out.println("Loss of each step in training process"); + for (double l : loss) + System.out.println(l); + System.out.println("Area under ROC = " + auROC); + // $example off$ + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java new file mode 100644 index 0000000000000..9219eef1ad2d6 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestClassificationExample.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.RandomForest; +import org.apache.spark.mllib.tree.model.RandomForestModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +public class JavaRandomForestClassificationExample { + public static void main(String[] args) { + // $example on$ + SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestClassificationExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Train a RandomForest model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + Integer numClasses = 2; + HashMap categoricalFeaturesInfo = new HashMap(); + Integer numTrees = 3; // Use more in practice. + String featureSubsetStrategy = "auto"; // Let the algorithm choose. + String impurity = "gini"; + Integer maxDepth = 5; + Integer maxBins = 32; + Integer seed = 12345; + + final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses, + categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, + seed); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double testErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / testData.count(); + System.out.println("Test Error: " + testErr); + System.out.println("Learned classification forest model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myRandomForestClassificationModel"); + RandomForestModel sameModel = RandomForestModel.load(jsc.sc(), + "target/tmp/myRandomForestClassificationModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java new file mode 100644 index 0000000000000..4db926a4218ff --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRandomForestRegressionExample.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.RandomForest; +import org.apache.spark.mllib.tree.model.RandomForestModel; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +// $example off$ + +public class JavaRandomForestRegressionExample { + public static void main(String[] args) { + // $example on$ + SparkConf sparkConf = new SparkConf().setAppName("JavaRandomForestRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Set parameters. + // Empty categoricalFeaturesInfo indicates all features are continuous. + Map categoricalFeaturesInfo = new HashMap(); + Integer numTrees = 3; // Use more in practice. + String featureSubsetStrategy = "auto"; // Let the algorithm choose. + String impurity = "variance"; + Integer maxDepth = 4; + Integer maxBins = 32; + Integer seed = 12345; + // Train a RandomForest model. + final RandomForestModel model = RandomForest.trainRegressor(trainingData, + categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double testMSE = + predictionAndLabel.map(new Function, Double>() { + @Override + public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override + public Double call(Double a, Double b) { + return a + b; + } + }) / testData.count(); + System.out.println("Test Mean Squared Error: " + testMSE); + System.out.println("Learned regression forest model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myRandomForestRegressionModel"); + RandomForestModel sameModel = RandomForestModel.load(jsc.sc(), + "target/tmp/myRandomForestRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java new file mode 100644 index 0000000000000..1065fde953b96 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.mllib.recommendation.Rating; +import org.apache.spark.SparkConf; +// $example off$ + +public class JavaRecommendationExample { + public static void main(String args[]) { + // $example on$ + SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example"); + JavaSparkContext jsc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/als/test.data"; + JavaRDD data = jsc.textFile(path); + JavaRDD ratings = data.map( + new Function() { + public Rating call(String s) { + String[] sarray = s.split(","); + return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), + Double.parseDouble(sarray[2])); + } + } + ); + + // Build the recommendation model using ALS + int rank = 10; + int numIterations = 10; + MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); + + // Evaluate the model on rating data + JavaRDD> userProducts = ratings.map( + new Function>() { + public Tuple2 call(Rating r) { + return new Tuple2(r.user(), r.product()); + } + } + ); + JavaPairRDD, Double> predictions = JavaPairRDD.fromJavaRDD( + model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( + new Function, Double>>() { + public Tuple2, Double> call(Rating r){ + return new Tuple2, Double>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )); + JavaRDD> ratesAndPreds = + JavaPairRDD.fromJavaRDD(ratings.map( + new Function, Double>>() { + public Tuple2, Double> call(Rating r){ + return new Tuple2, Double>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )).join(predictions).values(); + double MSE = JavaDoubleRDD.fromRDD(ratesAndPreds.map( + new Function, Object>() { + public Object call(Tuple2 pair) { + Double err = pair._1() - pair._2(); + return err * err; + } + } + ).rdd()).mean(); + System.out.println("Mean Squared Error = " + MSE); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myCollaborativeFilter"); + MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(jsc.sc(), + "target/tmp/myCollaborativeFilter"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java index 16ae9a3319ee2..337f8ffb5bfb0 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java @@ -66,7 +66,7 @@ public static void main(String[] args) { StreamingExamples.setStreamingLogLevels(); SparkConf sparkConf = new SparkConf().setAppName("JavaKafkaWordCount"); - // Create the context with a 1 second batch size + // Create the context with 2 seconds batch size JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, new Duration(2000)); int numThreads = Integer.parseInt(args[3]); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java index 99b63a2590ae2..c400e4237abe3 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -26,18 +26,15 @@ import com.google.common.base.Optional; import com.google.common.collect.Lists; -import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.*; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.StorageLevels; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.streaming.Durations; -import org.apache.spark.streaming.api.java.JavaDStream; -import org.apache.spark.streaming.api.java.JavaPairDStream; -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; -import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.State; +import org.apache.spark.streaming.StateSpec; +import org.apache.spark.streaming.Time; +import org.apache.spark.streaming.api.java.*; /** * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every @@ -63,25 +60,12 @@ public static void main(String[] args) { StreamingExamples.setStreamingLogLevels(); - // Update the cumulative count function - final Function2, Optional, Optional> updateFunction = - new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - Integer newSum = state.or(0); - for (Integer value : values) { - newSum += value; - } - return Optional.of(newSum); - } - }; - // Create the context with a 1 second batch size SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount"); JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); ssc.checkpoint("."); - // Initial RDD input to updateStateByKey + // Initial RDD input to trackStateByKey @SuppressWarnings("unchecked") List> tuples = Arrays.asList(new Tuple2("hello", 1), new Tuple2("world", 1)); @@ -105,9 +89,22 @@ public Tuple2 call(String s) { } }); + // Update the cumulative count function + final Function4, State, Optional>> trackStateFunc = + new Function4, State, Optional>>() { + + @Override + public Optional> call(Time time, String word, Optional one, State state) { + int sum = one.or(0) + (state.exists() ? state.get() : 0); + Tuple2 output = new Tuple2(word, sum); + state.update(sum); + return Optional.of(output); + } + }; + // This will give a Dstream made of state (which is the cumulative count of the words) - JavaPairDStream stateDstream = wordsDstream.updateStateByKey(updateFunction, - new HashPartitioner(ssc.sparkContext().defaultParallelism()), initialRDD); + JavaTrackStateDStream> stateDstream = + wordsDstream.trackStateByKey(StateSpec.function(trackStateFunc).initialState(initialRDD)); stateDstream.print(); ssc.start(); diff --git a/examples/src/main/python/ml/aft_survival_regression.py b/examples/src/main/python/ml/aft_survival_regression.py new file mode 100644 index 0000000000000..0ee01fd8258df --- /dev/null +++ b/examples/src/main/python/ml/aft_survival_regression.py @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.regression import AFTSurvivalRegression +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="AFTSurvivalRegressionExample") + sqlContext = SQLContext(sc) + + # $example on$ + training = sqlContext.createDataFrame([ + (1.218, 1.0, Vectors.dense(1.560, -0.605)), + (2.949, 0.0, Vectors.dense(0.346, 2.158)), + (3.627, 0.0, Vectors.dense(1.380, 0.231)), + (0.273, 1.0, Vectors.dense(0.520, 1.151)), + (4.199, 0.0, Vectors.dense(0.795, -0.226))], ["label", "censor", "features"]) + quantileProbabilities = [0.3, 0.6] + aft = AFTSurvivalRegression(quantileProbabilities=quantileProbabilities, + quantilesCol="quantiles") + + model = aft.fit(training) + + # Print the coefficients, intercept and scale parameter for AFT survival regression + print("Coefficients: " + str(model.coefficients)) + print("Intercept: " + str(model.intercept)) + print("Scale: " + str(model.scale)) + model.transform(training).show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/decision_tree_classification_example.py b/examples/src/main/python/ml/decision_tree_classification_example.py new file mode 100644 index 0000000000000..8cda56dbb9bdf --- /dev/null +++ b/examples/src/main/python/ml/decision_tree_classification_example.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Classification Example. +""" +from __future__ import print_function + +import sys + +# $example on$ +from pyspark import SparkContext, SQLContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import DecisionTreeClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="decision_tree_classification_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load the data stored in LIBSVM format as a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Index labels, adding metadata to the label column. + # Fit on whole dataset to include all labels in index. + labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) + # Automatically identify categorical features, and index them. + # We specify maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") + + # Chain indexers and tree in a Pipeline + pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt]) + + # Train model. This also runs the indexers. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "indexedLabel", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") + accuracy = evaluator.evaluate(predictions) + print("Test Error = %g " % (1.0 - accuracy)) + + treeModel = model.stages[2] + # summary only + print(treeModel) + # $example off$ diff --git a/examples/src/main/python/ml/decision_tree_regression_example.py b/examples/src/main/python/ml/decision_tree_regression_example.py new file mode 100644 index 0000000000000..439e398947499 --- /dev/null +++ b/examples/src/main/python/ml/decision_tree_regression_example.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Regression Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.regression import DecisionTreeRegressor +from pyspark.ml.feature import VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="decision_tree_classification_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load the data stored in LIBSVM format as a DataFrame. + data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + # Automatically identify categorical features, and index them. + # We specify maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + dt = DecisionTreeRegressor(featuresCol="indexedFeatures") + + # Chain indexer and tree in a Pipeline + pipeline = Pipeline(stages=[featureIndexer, dt]) + + # Train model. This also runs the indexer. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = RegressionEvaluator( + labelCol="label", predictionCol="prediction", metricName="rmse") + rmse = evaluator.evaluate(predictions) + print("Root Mean Squared Error (RMSE) on test data = %g" % rmse) + + treeModel = model.stages[1] + # summary only + print(treeModel) + # $example off$ diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py index 6446f0fe5eeab..c3bf8aa2eb1e6 100644 --- a/examples/src/main/python/ml/gradient_boosted_trees.py +++ b/examples/src/main/python/ml/gradient_boosted_trees.py @@ -24,7 +24,6 @@ from pyspark.ml.feature import StringIndexer from pyspark.ml.regression import GBTRegressor from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics -from pyspark.mllib.util import MLUtils from pyspark.sql import Row, SQLContext """ @@ -70,8 +69,8 @@ def testRegression(train, test): sc = SparkContext(appName="PythonGBTExample") sqlContext = SQLContext(sc) - # Load and parse the data file into a dataframe. - df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + # Load the data stored in LIBSVM format as a DataFrame. + df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Map labels into an indexed column of labels in [0, numLabels) stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py index 55afe1b207fe0..4cd027fdfbe8a 100644 --- a/examples/src/main/python/ml/logistic_regression.py +++ b/examples/src/main/python/ml/logistic_regression.py @@ -23,7 +23,6 @@ from pyspark.ml.classification import LogisticRegression from pyspark.mllib.evaluation import MulticlassMetrics from pyspark.ml.feature import StringIndexer -from pyspark.mllib.util import MLUtils from pyspark.sql import SQLContext """ @@ -41,8 +40,8 @@ sc = SparkContext(appName="PythonLogisticRegressionExample") sqlContext = SQLContext(sc) - # Load and parse the data file into a dataframe. - df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + # Load the data stored in LIBSVM format as a DataFrame. + df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Map labels into an indexed column of labels in [0, numLabels) stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") diff --git a/examples/src/main/python/ml/multilayer_perceptron_classification.py b/examples/src/main/python/ml/multilayer_perceptron_classification.py new file mode 100644 index 0000000000000..f84588f547fff --- /dev/null +++ b/examples/src/main/python/ml/multilayer_perceptron_classification.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.classification import MultilayerPerceptronClassifier +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="multilayer_perceptron_classification_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load training data + data = sqlContext.read.format("libsvm")\ + .load("data/mllib/sample_multiclass_classification_data.txt") + # Split the data into train and test + splits = data.randomSplit([0.6, 0.4], 1234) + train = splits[0] + test = splits[1] + # specify layers for the neural network: + # input layer of size 4 (features), two intermediate of size 5 and 4 + # and output of size 3 (classes) + layers = [4, 5, 4, 3] + # create the trainer and set its parameters + trainer = MultilayerPerceptronClassifier(maxIter=100, layers=layers, blockSize=128, seed=1234) + # train the model + model = trainer.fit(train) + # compute precision on the test set + result = model.transform(test) + predictionAndLabels = result.select("prediction", "label") + evaluator = MulticlassClassificationEvaluator(metricName="precision") + print("Precision:" + str(evaluator.evaluate(predictionAndLabels))) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py index c7730e1bfacd9..dc6a778670193 100644 --- a/examples/src/main/python/ml/random_forest_example.py +++ b/examples/src/main/python/ml/random_forest_example.py @@ -74,8 +74,8 @@ def testRegression(train, test): sc = SparkContext(appName="PythonRandomForestExample") sqlContext = SQLContext(sc) - # Load and parse the data file into a dataframe. - df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + # Load the data stored in LIBSVM format as a DataFrame. + df = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Map labels into an indexed column of labels in [0, numLabels) stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") diff --git a/examples/src/main/python/mllib/decision_tree_classification_example.py b/examples/src/main/python/mllib/decision_tree_classification_example.py new file mode 100644 index 0000000000000..1b529768b6c62 --- /dev/null +++ b/examples/src/main/python/mllib/decision_tree_classification_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Classification Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import DecisionTree, DecisionTreeModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonDecisionTreeClassificationExample") + + # $example on$ + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, + impurity='gini', maxDepth=5, maxBins=32) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) + print('Test Error = ' + str(testErr)) + print('Learned classification tree model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myDecisionTreeClassificationModel") + sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel") + # $example off$ diff --git a/examples/src/main/python/mllib/decision_tree_regression_example.py b/examples/src/main/python/mllib/decision_tree_regression_example.py new file mode 100644 index 0000000000000..cf518eac67e81 --- /dev/null +++ b/examples/src/main/python/mllib/decision_tree_regression_example.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Regression Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import DecisionTree, DecisionTreeModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonDecisionTreeRegressionExample") + + # $example on$ + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={}, + impurity='variance', maxDepth=5, maxBins=32) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() /\ + float(testData.count()) + print('Test Mean Squared Error = ' + str(testMSE)) + print('Learned regression tree model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myDecisionTreeRegressionModel") + sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel") + # $example off$ diff --git a/examples/src/main/python/mllib/gradient_boosting_classification_example.py b/examples/src/main/python/mllib/gradient_boosting_classification_example.py new file mode 100644 index 0000000000000..a94ea0d582e59 --- /dev/null +++ b/examples/src/main/python/mllib/gradient_boosting_classification_example.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Gradient Boosted Trees Classification Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonGradientBoostedTreesClassificationExample") + # $example on$ + # Load and parse the data file. + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a GradientBoostedTrees model. + # Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. + # (b) Use more iterations in practice. + model = GradientBoostedTrees.trainClassifier(trainingData, + categoricalFeaturesInfo={}, numIterations=3) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) + print('Test Error = ' + str(testErr)) + print('Learned classification GBT model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myGradientBoostingClassificationModel") + sameModel = GradientBoostedTreesModel.load(sc, + "target/tmp/myGradientBoostingClassificationModel") + # $example off$ diff --git a/examples/src/main/python/mllib/gradient_boosting_regression_example.py b/examples/src/main/python/mllib/gradient_boosting_regression_example.py new file mode 100644 index 0000000000000..86040799dc1d9 --- /dev/null +++ b/examples/src/main/python/mllib/gradient_boosting_regression_example.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Gradient Boosted Trees Regression Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonGradientBoostedTreesRegressionExample") + # $example on$ + # Load and parse the data file. + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a GradientBoostedTrees model. + # Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous. + # (b) Use more iterations in practice. + model = GradientBoostedTrees.trainRegressor(trainingData, + categoricalFeaturesInfo={}, numIterations=3) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() /\ + float(testData.count()) + print('Test Mean Squared Error = ' + str(testMSE)) + print('Learned regression GBT model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myGradientBoostingRegressionModel") + sameModel = GradientBoostedTreesModel.load(sc, "target/tmp/myGradientBoostingRegressionModel") + # $example off$ diff --git a/examples/src/main/python/mllib/naive_bayes_example.py b/examples/src/main/python/mllib/naive_bayes_example.py index a2e7dacf25491..f5e120c678fcf 100644 --- a/examples/src/main/python/mllib/naive_bayes_example.py +++ b/examples/src/main/python/mllib/naive_bayes_example.py @@ -20,6 +20,7 @@ """ from __future__ import print_function +from pyspark import SparkContext # $example on$ from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel from pyspark.mllib.linalg import Vectors diff --git a/examples/src/main/python/mllib/random_forest_classification_example.py b/examples/src/main/python/mllib/random_forest_classification_example.py new file mode 100644 index 0000000000000..324ba50625d25 --- /dev/null +++ b/examples/src/main/python/mllib/random_forest_classification_example.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Random Forest Classification Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import RandomForest, RandomForestModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonRandomForestClassificationExample") + # $example on$ + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a RandomForest model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + # Note: Use larger numTrees in practice. + # Setting featureSubsetStrategy="auto" lets the algorithm choose. + model = RandomForest.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, + numTrees=3, featureSubsetStrategy="auto", + impurity='gini', maxDepth=4, maxBins=32) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) + print('Test Error = ' + str(testErr)) + print('Learned classification forest model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myRandomForestClassificationModel") + sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestClassificationModel") + # $example off$ diff --git a/examples/src/main/python/mllib/random_forest_regression_example.py b/examples/src/main/python/mllib/random_forest_regression_example.py new file mode 100644 index 0000000000000..f7aa6114eceb3 --- /dev/null +++ b/examples/src/main/python/mllib/random_forest_regression_example.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Random Forest Regression Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import RandomForest, RandomForestModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonRandomForestRegressionExample") + # $example on$ + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a RandomForest model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + # Note: Use larger numTrees in practice. + # Setting featureSubsetStrategy="auto" lets the algorithm choose. + model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo={}, + numTrees=3, featureSubsetStrategy="auto", + impurity='variance', maxDepth=4, maxBins=32) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() /\ + float(testData.count()) + print('Test Mean Squared Error = ' + str(testMSE)) + print('Learned regression forest model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myRandomForestRegressionModel") + sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestRegressionModel") + # $example off$ diff --git a/examples/src/main/python/mllib/recommendation_example.py b/examples/src/main/python/mllib/recommendation_example.py new file mode 100644 index 0000000000000..615db0749b182 --- /dev/null +++ b/examples/src/main/python/mllib/recommendation_example.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Collaborative Filtering Classification Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext + +# $example on$ +from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonCollaborativeFilteringExample") + # $example on$ + # Load and parse the data + data = sc.textFile("data/mllib/als/test.data") + ratings = data.map(lambda l: l.split(','))\ + .map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2]))) + + # Build the recommendation model using Alternating Least Squares + rank = 10 + numIterations = 10 + model = ALS.train(ratings, rank, numIterations) + + # Evaluate the model on training data + testdata = ratings.map(lambda p: (p[0], p[1])) + predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) + ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) + MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).mean() + print("Mean Squared Error = " + str(MSE)) + + # Save and load model + model.save(sc, "target/tmp/myCollaborativeFilter") + sameModel = MatrixFactorizationModel.load(sc, "target/tmp/myCollaborativeFilter") + # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala new file mode 100644 index 0000000000000..5da285e83681f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.ml.regression.AFTSurvivalRegression +import org.apache.spark.mllib.linalg.Vectors +// $example off$ + +/** + * An example for AFTSurvivalRegression. + */ +object AFTSurvivalRegressionExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("AFTSurvivalRegressionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val training = sqlContext.createDataFrame(Seq( + (1.218, 1.0, Vectors.dense(1.560, -0.605)), + (2.949, 0.0, Vectors.dense(0.346, 2.158)), + (3.627, 0.0, Vectors.dense(1.380, 0.231)), + (0.273, 1.0, Vectors.dense(0.520, 1.151)), + (4.199, 0.0, Vectors.dense(0.795, -0.226)) + )).toDF("label", "censor", "features") + val quantileProbabilities = Array(0.3, 0.6) + val aft = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") + + val model = aft.fit(training) + + // Print the coefficients, intercept and scale parameter for AFT survival regression + println(s"Coefficients: ${model.coefficients} Intercept: " + + s"${model.intercept} Scale: ${model.scale}") + model.transform(training).show(false) + // $example off$ + + sc.stop() + } +} +// scalastyle:off println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala new file mode 100644 index 0000000000000..ff8a0a90f1e44 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.DecisionTreeClassifier +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +// $example off$ + +object DecisionTreeClassificationExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeClassificationExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) + // Automatically identify categorical features, and index them. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a DecisionTree model. + val dt = new DecisionTreeClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + + // Convert indexed labels back to original labels. + val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + + // Chain indexers and tree in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter)) + + // Train model. This also runs the indexers. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") + val accuracy = evaluator.evaluate(predictions) + println("Test Error = " + (1.0 - accuracy)) + + val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] + println("Learned classification tree model:\n" + treeModel.toDebugString) + // $example off$ + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index f28671f7869fc..c4e98dfaca6c9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -32,10 +32,7 @@ import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTree import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics} import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.StringType import org.apache.spark.sql.{SQLContext, DataFrame} @@ -138,15 +135,18 @@ object DecisionTreeExample { /** Load a dataset from the given path, using the given format */ private[ml] def loadData( - sc: SparkContext, + sqlContext: SQLContext, path: String, format: String, - expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = { + expectedNumFeatures: Option[Int] = None): DataFrame = { + import sqlContext.implicits._ + format match { - case "dense" => MLUtils.loadLabeledPoints(sc, path) + case "dense" => MLUtils.loadLabeledPoints(sqlContext.sparkContext, path).toDF() case "libsvm" => expectedNumFeatures match { - case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures) - case None => MLUtils.loadLibSVMFile(sc, path) + case Some(numFeatures) => sqlContext.read.option("numFeatures", numFeatures.toString) + .format("libsvm").load(path) + case None => sqlContext.read.format("libsvm").load(path) } case _ => throw new IllegalArgumentException(s"Bad data format: $format") } @@ -169,36 +169,22 @@ object DecisionTreeExample { algo: String, fracTest: Double): (DataFrame, DataFrame) = { val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ // Load training data - val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat) + val origExamples: DataFrame = loadData(sqlContext, input, dataFormat) // Load or create test set - val splits: Array[RDD[LabeledPoint]] = if (testInput != "") { + val dataframes: Array[DataFrame] = if (testInput != "") { // Load testInput. - val numFeatures = origExamples.take(1)(0).features.size - val origTestExamples: RDD[LabeledPoint] = - loadData(sc, testInput, dataFormat, Some(numFeatures)) + val numFeatures = origExamples.first().getAs[Vector](1).size + val origTestExamples: DataFrame = + loadData(sqlContext, testInput, dataFormat, Some(numFeatures)) Array(origExamples, origTestExamples) } else { // Split input into training, test. origExamples.randomSplit(Array(1.0 - fracTest, fracTest), seed = 12345) } - // For classification, convert labels to Strings since we will index them later with - // StringIndexer. - def labelsToStrings(data: DataFrame): DataFrame = { - algo.toLowerCase match { - case "classification" => - data.withColumn("labelString", data("label").cast(StringType)) - case "regression" => - data - case _ => - throw new IllegalArgumentException("Algo ${params.algo} not supported.") - } - } - val dataframes = splits.map(_.toDF()).map(labelsToStrings) val training = dataframes(0).cache() val test = dataframes(1).cache() @@ -230,7 +216,7 @@ object DecisionTreeExample { val labelColName = if (algo == "classification") "indexedLabel" else "label" if (algo == "classification") { val labelIndexer = new StringIndexer() - .setInputCol("labelString") + .setInputCol("label") .setOutputCol(labelColName) stages += labelIndexer } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala new file mode 100644 index 0000000000000..fc402724d2156 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.regression.DecisionTreeRegressor +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.evaluation.RegressionEvaluator +// $example off$ +object DecisionTreeRegressionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeRegressionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Automatically identify categorical features, and index them. + // Here, we treat features with > 4 distinct values as continuous. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a DecisionTree model. + val dt = new DecisionTreeRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + + // Chain indexer and tree in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(featureIndexer, dt)) + + // Train model. This also runs the indexer. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") + val rmse = evaluator.evaluate(predictions) + println("Root Mean Squared Error (RMSE) on test data = " + rmse) + + val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] + println("Learned regression tree model:\n" + treeModel.toDebugString) + // $example off$ + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index f4a15f806ea81..6b0be0f34e196 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -153,7 +153,7 @@ object GBTExample { val labelColName = if (algo == "classification") "indexedLabel" else "label" if (algo == "classification") { val labelIndexer = new StringIndexer() - .setInputCol("labelString") + .setInputCol("label") .setOutputCol(labelColName) stages += labelIndexer } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala index b73299fb12d3f..50998c94de3d0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -131,7 +131,7 @@ object LinearRegressionExample { println(s"Training time: $elapsedTime seconds") // Print the weights and intercept for linear regression. - println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}") + println(s"Weights: ${lirModel.coefficients} Intercept: ${lirModel.intercept}") println("Training data results:") DecisionTreeExample.evaluateRegressionModel(lirModel, training, "label") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala index 8e3760ddb50a9..a380c90662a50 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -125,7 +125,7 @@ object LogisticRegressionExample { val stages = new mutable.ArrayBuffer[PipelineStage]() val labelIndexer = new StringIndexer() - .setInputCol("labelString") + .setInputCol("label") .setOutputCol("indexedLabel") stages += labelIndexer @@ -149,7 +149,7 @@ object LogisticRegressionExample { val lorModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel] // Print the weights and intercept for logistic regression. - println(s"Weights: ${lorModel.weights} Intercept: ${lorModel.intercept}") + println(s"Weights: ${lorModel.coefficients} Intercept: ${lorModel.intercept}") println("Training data results:") DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, "indexedLabel") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala new file mode 100644 index 0000000000000..146b83c8be490 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.sql.SQLContext +// $example on$ +import org.apache.spark.ml.classification.MultilayerPerceptronClassifier +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +// $example off$ + +/** + * An example for Multilayer Perceptron Classification. + */ +object MultilayerPerceptronClassifierExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("MultilayerPerceptronClassifierExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + val data = sqlContext.read.format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt") + // Split the data into train and test + val splits = data.randomSplit(Array(0.6, 0.4), seed = 1234L) + val train = splits(0) + val test = splits(1) + // specify layers for the neural network: + // input layer of size 4 (features), two intermediate of size 5 and 4 + // and output of size 3 (classes) + val layers = Array[Int](4, 5, 4, 3) + // create the trainer and set its parameters + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(128) + .setSeed(1234L) + .setMaxIter(100) + // train the model + val model = trainer.fit(train) + // compute precision on the test set + val result = model.transform(test) + val predictionAndLabels = result.select("prediction", "label") + val evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision") + println("Precision:" + evaluator.evaluate(predictionAndLabels)) + // $example off$ + + sc.stop() + } +} +// scalastyle:off println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index bab31f585b0ef..8e4f1b09a24b5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -27,9 +27,8 @@ import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.ml.classification.{OneVsRest, LogisticRegression} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SQLContext /** @@ -111,24 +110,24 @@ object OneVsRestExample { private def run(params: Params) { val conf = new SparkConf().setAppName(s"OneVsRestExample with $params") val sc = new SparkContext(conf) - val inputData = MLUtils.loadLibSVMFile(sc, params.input) val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val inputData = sqlContext.read.format("libsvm").load(params.input) // compute the train/test split: if testInput is not provided use part of input. val data = params.testInput match { case Some(t) => { // compute the number of features in the training set. - val numFeatures = inputData.first().features.size - val testData = MLUtils.loadLibSVMFile(sc, t, numFeatures) - Array[RDD[LabeledPoint]](inputData, testData) + val numFeatures = inputData.first().getAs[Vector](1).size + val testData = sqlContext.read.option("numFeatures", numFeatures.toString) + .format("libsvm").load(t) + Array[DataFrame](inputData, testData) } case None => { val f = params.fracTest inputData.randomSplit(Array(1 - f, f), seed = 12345) } } - val Array(train, test) = data.map(_.toDF().cache()) + val Array(train, test) = data.map(_.cache()) // instantiate the base classifier val classifier = new LogisticRegression() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index 109178f4137b2..7a00d99dfe53d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -159,7 +159,7 @@ object RandomForestExample { val labelColName = if (algo == "classification") "indexedLabel" else "label" if (algo == "classification") { val labelIndexer = new StringIndexer() - .setInputCol("labelString") + .setInputCol("label") .setOutputCol(labelColName) stages += labelIndexer } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala index 1abdf219b1c00..cd1b0e9358beb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala @@ -20,7 +20,6 @@ package org.apache.spark.examples.ml import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} -import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.SQLContext import org.apache.spark.{SparkConf, SparkContext} @@ -39,10 +38,9 @@ object TrainValidationSplitExample { val conf = new SparkConf().setAppName("TrainValidationSplitExample") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ // Prepare training and test data. - val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) val lr = new LinearRegression() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala new file mode 100644 index 0000000000000..d427bbadaa0c1 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.{SparkConf, SparkContext} + +object DecisionTreeClassificationExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeClassificationExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a DecisionTree model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + val numClasses = 2 + val categoricalFeaturesInfo = Map[Int, Int]() + val impurity = "gini" + val maxDepth = 5 + val maxBins = 32 + + val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, + impurity, maxDepth, maxBins) + + // Evaluate model on test instances and compute test error + val labelAndPreds = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count() + println("Test Error = " + testErr) + println("Learned classification tree model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myDecisionTreeClassificationModel") + val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala new file mode 100644 index 0000000000000..fb05e7d9c5065 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.{SparkConf, SparkContext} + +object DecisionTreeRegressionExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeRegressionExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a DecisionTree model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + val categoricalFeaturesInfo = Map[Int, Int]() + val impurity = "variance" + val maxDepth = 5 + val maxBins = 32 + + val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity, + maxDepth, maxBins) + + // Evaluate model on test instances and compute test error + val labelsAndPredictions = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean() + println("Test Mean Squared Error = " + testMSE) + println("Learned regression tree model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myDecisionTreeRegressionModel") + val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala new file mode 100644 index 0000000000000..139e1f909bdce --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.mllib.tree.GradientBoostedTrees +import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object GradientBoostingClassificationExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("GradientBoostedTreesClassificationExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a GradientBoostedTrees model. + // The defaultParams for Classification use LogLoss by default. + val boostingStrategy = BoostingStrategy.defaultParams("Classification") + boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. + boostingStrategy.treeStrategy.numClasses = 2 + boostingStrategy.treeStrategy.maxDepth = 5 + // Empty categoricalFeaturesInfo indicates all features are continuous. + boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() + + val model = GradientBoostedTrees.train(trainingData, boostingStrategy) + + // Evaluate model on test instances and compute test error + val labelAndPreds = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() + println("Test Error = " + testErr) + println("Learned classification GBT model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myGradientBoostingClassificationModel") + val sameModel = GradientBoostedTreesModel.load(sc, + "target/tmp/myGradientBoostingClassificationModel") + // $example off$ + } +} +// scalastyle:on println + + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala new file mode 100644 index 0000000000000..3dc86da8e4d2b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.mllib.tree.GradientBoostedTrees +import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object GradientBoostingRegressionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("GradientBoostedTreesRegressionExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a GradientBoostedTrees model. + // The defaultParams for Regression use SquaredError by default. + val boostingStrategy = BoostingStrategy.defaultParams("Regression") + boostingStrategy.numIterations = 3 // Note: Use more iterations in practice. + boostingStrategy.treeStrategy.maxDepth = 5 + // Empty categoricalFeaturesInfo indicates all features are continuous. + boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]() + + val model = GradientBoostedTrees.train(trainingData, boostingStrategy) + + // Evaluate model on test instances and compute test error + val labelsAndPredictions = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() + println("Test Mean Squared Error = " + testMSE) + println("Learned regression GBT model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myGradientBoostingRegressionModel") + val sameModel = GradientBoostedTreesModel.load(sc, + "target/tmp/myGradientBoostingRegressionModel") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala new file mode 100644 index 0000000000000..61d2e7715f53d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater} +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +import org.apache.spark.{SparkConf, SparkContext} + +object LBFGSExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("LBFGSExample") + val sc = new SparkContext(conf) + + // $example on$ + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + val numFeatures = data.take(1)(0).features.size + + // Split data into training (60%) and test (40%). + val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) + + // Append 1 into the training data as intercept. + val training = splits(0).map(x => (x.label, MLUtils.appendBias(x.features))).cache() + + val test = splits(1) + + // Run training algorithm to build the model + val numCorrections = 10 + val convergenceTol = 1e-4 + val maxNumIterations = 20 + val regParam = 0.1 + val initialWeightsWithIntercept = Vectors.dense(new Array[Double](numFeatures + 1)) + + val (weightsWithIntercept, loss) = LBFGS.runLBFGS( + training, + new LogisticGradient(), + new SquaredL2Updater(), + numCorrections, + convergenceTol, + maxNumIterations, + regParam, + initialWeightsWithIntercept) + + val model = new LogisticRegressionModel( + Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)), + weightsWithIntercept(weightsWithIntercept.size - 1)) + + // Clear the default threshold. + model.clearThreshold() + + // Compute raw scores on the test set. + val scoreAndLabels = test.map { point => + val score = model.predict(point.features) + (score, point.label) + } + + // Get evaluation metrics. + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + val auROC = metrics.areaUnderROC() + + println("Loss of each step in training process") + loss.foreach(println) + println("Area under ROC = " + auROC) + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala new file mode 100644 index 0000000000000..5e55abd5121c4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.mllib.tree.RandomForest +import org.apache.spark.mllib.tree.model.RandomForestModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object RandomForestClassificationExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RandomForestClassificationExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a RandomForest model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + val numClasses = 2 + val categoricalFeaturesInfo = Map[Int, Int]() + val numTrees = 3 // Use more in practice. + val featureSubsetStrategy = "auto" // Let the algorithm choose. + val impurity = "gini" + val maxDepth = 4 + val maxBins = 32 + + val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, + numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins) + + // Evaluate model on test instances and compute test error + val labelAndPreds = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() + println("Test Error = " + testErr) + println("Learned classification forest model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myRandomForestClassificationModel") + val sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestClassificationModel") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala new file mode 100644 index 0000000000000..a54fb3ab7e37a --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.mllib.tree.RandomForest +import org.apache.spark.mllib.tree.model.RandomForestModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object RandomForestRegressionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("RandomForestRegressionExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a RandomForest model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + val numClasses = 2 + val categoricalFeaturesInfo = Map[Int, Int]() + val numTrees = 3 // Use more in practice. + val featureSubsetStrategy = "auto" // Let the algorithm choose. + val impurity = "variance" + val maxDepth = 4 + val maxBins = 32 + + val model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo, + numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins) + + // Evaluate model on test instances and compute test error + val labelsAndPredictions = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() + println("Test Mean Squared Error = " + testMSE) + println("Learned regression forest model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myRandomForestRegressionModel") + val sameModel = RandomForestModel.load(sc, "target/tmp/myRandomForestRegressionModel") + // $example off$ + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala new file mode 100644 index 0000000000000..64e4602465444 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.mllib.recommendation.ALS +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel +import org.apache.spark.mllib.recommendation.Rating +// $example off$ + +object RecommendationExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("CollaborativeFilteringExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data + val data = sc.textFile("data/mllib/als/test.data") + val ratings = data.map(_.split(',') match { case Array(user, item, rate) => + Rating(user.toInt, item.toInt, rate.toDouble) + }) + + // Build the recommendation model using ALS + val rank = 10 + val numIterations = 10 + val model = ALS.train(ratings, rank, numIterations, 0.01) + + // Evaluate the model on rating data + val usersProducts = ratings.map { case Rating(user, product, rate) => + (user, product) + } + val predictions = + model.predict(usersProducts).map { case Rating(user, product, rate) => + ((user, product), rate) + } + val ratesAndPreds = ratings.map { case Rating(user, product, rate) => + ((user, product), rate) + }.join(predictions) + val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => + val err = (r1 - r2) + err * err + }.mean() + println("Mean Squared Error = " + MSE) + + // Save and load model + model.save(sc, "target/tmp/myCollaborativeFilter") + val sameModel = MatrixFactorizationModel.load(sc, "target/tmp/myCollaborativeFilter") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 02ba1c2eed0f7..a4f847f118b2c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -44,24 +44,12 @@ object StatefulNetworkWordCount { StreamingExamples.setStreamingLogLevels() - val updateFunc = (values: Seq[Int], state: Option[Int]) => { - val currentCount = values.sum - - val previousCount = state.getOrElse(0) - - Some(currentCount + previousCount) - } - - val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => { - iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) - } - val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount") // Create the context with a 1 second batch size val ssc = new StreamingContext(sparkConf, Seconds(1)) ssc.checkpoint(".") - // Initial RDD input to updateStateByKey + // Initial RDD input to trackStateByKey val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1))) // Create a ReceiverInputDStream on target ip:port and count the @@ -71,9 +59,16 @@ object StatefulNetworkWordCount { val wordDstream = words.map(x => (x, 1)) // Update the cumulative count using updateStateByKey - // This will give a Dstream made of state (which is the cumulative count of the words) - val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc, - new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD) + // This will give a DStream made of state (which is the cumulative count of the words) + val trackStateFunc = (batchTime: Time, word: String, one: Option[Int], state: State[Int]) => { + val sum = one.getOrElse(0) + state.getOption.getOrElse(0) + val output = (word, sum) + state.update(sum) + Some(output) + } + + val stateDstream = wordDstream.trackStateByKey( + StateSpec.function(trackStateFunc).initialState(initialRDD)) stateDstream.print() ssc.start() ssc.awaitTermination() diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala index 70018c86f92be..fe5dcc8e4b9de 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.streaming.flume import java.net.{InetSocketAddress, ServerSocket} import java.nio.ByteBuffer +import java.util.{List => JList} import java.util.Collections import scala.collection.JavaConverters._ @@ -59,10 +60,10 @@ private[flume] class FlumeTestUtils { } /** Send data to the flume receiver */ - def writeInput(input: Seq[String], enableCompression: Boolean): Unit = { + def writeInput(input: JList[String], enableCompression: Boolean): Unit = { val testAddress = new InetSocketAddress("localhost", testPort) - val inputEvents = input.map { item => + val inputEvents = input.asScala.map { item => val event = new AvroFlumeEvent event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8))) event.setHeaders(Collections.singletonMap("test", "header")) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala index a2ab320957db3..bfe7548d4f50e 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.flume import java.util.concurrent._ -import java.util.{Map => JMap, Collections} +import java.util.{Collections, List => JList, Map => JMap} import scala.collection.mutable.ArrayBuffer @@ -137,7 +137,8 @@ private[flume] class PollingFlumeTestUtils { /** * A Python-friendly method to assert the output */ - def assertOutput(outputHeaders: Seq[JMap[String, String]], outputBodies: Seq[String]): Unit = { + def assertOutput( + outputHeaders: JList[JMap[String, String]], outputBodies: JList[String]): Unit = { require(outputHeaders.size == outputBodies.size) val eventSize = outputHeaders.size if (eventSize != totalEventsPerChannel * channels.size) { @@ -151,8 +152,8 @@ private[flume] class PollingFlumeTestUtils { var found = false var j = 0 while (j < eventSize && !found) { - if (eventBodyToVerify == outputBodies(j) && - eventHeaderToVerify == outputHeaders(j)) { + if (eventBodyToVerify == outputBodies.get(j) && + eventHeaderToVerify == outputHeaders.get(j)) { found = true counter += 1 } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala index 1a900007b696b..79077e4a49e1a 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala @@ -37,7 +37,7 @@ class TestOutputStream[T: ClassTag](parent: DStream[T], extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() output += collected - }) { + }, false) { // This is to clear the output buffer every it is read from a checkpoint @throws(classOf[IOException]) diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index ff2fb8eed204c..5fd2711f5f7df 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -120,7 +120,7 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log case (key, value) => (key.toString, value.toString) }).map(_.asJava) val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) - utils.assertOutput(headers, bodies) + utils.assertOutput(headers.asJava, bodies.asJava) } } finally { ssc.stop() diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 5ffb60bd602f9..f315e0a7ca23c 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -54,7 +54,7 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w val outputBuffer = startContext(utils.getTestPort(), testCompression) eventually(timeout(10 seconds), interval(100 milliseconds)) { - utils.writeInput(input, testCompression) + utils.writeInput(input.asJava, testCompression) } eventually(timeout(10 seconds), interval(100 milliseconds)) { diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 73091cfe2c09e..163ae92c12c6d 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -31,9 +31,12 @@ import org.apache.spark.HashPartitioner; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.Function4; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaTrackStateDStream; /** * Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8 @@ -831,4 +834,44 @@ public void testFlatMapValues() { Assert.assertEquals(expected, result); } + /** + * This test is only for testing the APIs. It's not necessary to run it. + */ + public void testTrackStateByAPI() { + JavaPairRDD initialRDD = null; + JavaPairDStream wordsDstream = null; + + JavaTrackStateDStream stateDstream = + wordsDstream.trackStateByKey( + StateSpec. function((time, key, value, state) -> { + // Use all State's methods here + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return Optional.of(2.0); + }).initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); + + JavaPairDStream emittedRecords = stateDstream.stateSnapshots(); + + JavaTrackStateDStream stateDstream2 = + wordsDstream.trackStateByKey( + StateSpec.function((value, state) -> { + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return 2.0; + }).initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); + + JavaPairDStream emittedRecords2 = stateDstream2.stateSnapshots(); + } } diff --git a/extras/java8-tests/src/test/scala/org/apache/spark/JDK8ScalaSuite.scala b/extras/java8-tests/src/test/scala/org/apache/spark/JDK8ScalaSuite.scala new file mode 100644 index 0000000000000..fa0681db41088 --- /dev/null +++ b/extras/java8-tests/src/test/scala/org/apache/spark/JDK8ScalaSuite.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Test cases where JDK8-compiled Scala user code is used with Spark. + */ +class JDK8ScalaSuite extends SparkFunSuite with SharedSparkContext { + test("basic RDD closure test (SPARK-6152)") { + sc.parallelize(1 to 1000).map(x => x * x).count() + } +} diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index ef72d97eae69d..519a920279c97 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -64,6 +64,12 @@ aws-java-sdk ${aws.java.sdk.version} + + com.amazonaws + amazon-kinesis-producer + ${aws.kinesis.producer.version} + test + org.mockito mockito-core diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 000897a4e7290..691c1790b207f 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -23,6 +23,7 @@ import scala.util.control.NonFatal import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord import com.amazonaws.services.kinesis.model._ import org.apache.spark._ @@ -210,7 +211,10 @@ class KinesisSequenceRangeIterator( s"getting records using shard iterator") { client.getRecords(getRecordsRequest) } - (getRecordsResult.getRecords.iterator().asScala, getRecordsResult.getNextShardIterator) + // De-aggregate records, if KPL was used in producing the records. The KCL automatically + // handles de-aggregation during regular operation. This code path is used during recovery + val recordIterator = UserRecord.deaggregate(getRecordsResult.getRecords) + (recordIterator.iterator().asScala, getRecordsResult.getNextShardIterator) } /** diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala deleted file mode 100644 index 83a4537559512..0000000000000 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.kinesis - -import org.apache.spark.Logging -import org.apache.spark.streaming.Duration -import org.apache.spark.util.{Clock, ManualClock, SystemClock} - -/** - * This is a helper class for managing checkpoint clocks. - * - * @param checkpointInterval - * @param currentClock. Default to current SystemClock if none is passed in (mocking purposes) - */ -private[kinesis] class KinesisCheckpointState( - checkpointInterval: Duration, - currentClock: Clock = new SystemClock()) - extends Logging { - - /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ - val checkpointClock = new ManualClock() - checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds) - - /** - * Check if it's time to checkpoint based on the current time and the derived time - * for the next checkpoint - * - * @return true if it's time to checkpoint - */ - def shouldCheckpoint(): Boolean = { - new SystemClock().getTimeMillis() > checkpointClock.getTimeMillis() - } - - /** - * Advance the checkpoint clock by the checkpoint interval. - */ - def advanceCheckpoint(): Unit = { - checkpointClock.advance(checkpointInterval.milliseconds) - } -} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala new file mode 100644 index 0000000000000..1ca6d4302c2bb --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.util.concurrent._ + +import scala.util.control.NonFatal + +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason + +import org.apache.spark.Logging +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.util.RecurringTimer +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} + +/** + * This is a helper class for managing Kinesis checkpointing. + * + * @param receiver The receiver that keeps track of which sequence numbers we can checkpoint + * @param checkpointInterval How frequently we will checkpoint to DynamoDB + * @param workerId Worker Id of KCL worker for logging purposes + * @param clock In order to use ManualClocks for the purpose of testing + */ +private[kinesis] class KinesisCheckpointer( + receiver: KinesisReceiver[_], + checkpointInterval: Duration, + workerId: String, + clock: Clock = new SystemClock) extends Logging { + + // a map from shardId's to checkpointers + private val checkpointers = new ConcurrentHashMap[String, IRecordProcessorCheckpointer]() + + private val lastCheckpointedSeqNums = new ConcurrentHashMap[String, String]() + + private val checkpointerThread: RecurringTimer = startCheckpointerThread() + + /** Update the checkpointer instance to the most recent one for the given shardId. */ + def setCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + checkpointers.put(shardId, checkpointer) + } + + /** + * Stop tracking the specified shardId. + * + * If a checkpointer is provided, e.g. on IRecordProcessor.shutdown [[ShutdownReason.TERMINATE]], + * we will use that to make the final checkpoint. If `null` is provided, we will not make the + * checkpoint, e.g. in case of [[ShutdownReason.ZOMBIE]]. + */ + def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + synchronized { + checkpointers.remove(shardId) + checkpoint(shardId, checkpointer) + } + } + + /** Perform the checkpoint. */ + private def checkpoint(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + try { + if (checkpointer != null) { + receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum => + val lastSeqNum = lastCheckpointedSeqNums.get(shardId) + // Kinesis sequence numbers are monotonically increasing strings, therefore we can do + // safely do the string comparison + if (lastSeqNum == null || latestSeqNum > lastSeqNum) { + /* Perform the checkpoint */ + KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 100) + logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint at sequence number" + + s" $latestSeqNum for shardId $shardId") + lastCheckpointedSeqNums.put(shardId, latestSeqNum) + } + } + } else { + logDebug(s"Checkpointing skipped for shardId $shardId. Checkpointer not set.") + } + } catch { + case NonFatal(e) => + logWarning(s"Failed to checkpoint shardId $shardId to DynamoDB.", e) + } + } + + /** Checkpoint the latest saved sequence numbers for all active shardId's. */ + private def checkpointAll(): Unit = synchronized { + // if this method throws an exception, then the scheduled task will not run again + try { + val shardIds = checkpointers.keys() + while (shardIds.hasMoreElements) { + val shardId = shardIds.nextElement() + checkpoint(shardId, checkpointers.get(shardId)) + } + } catch { + case NonFatal(e) => + logWarning("Failed to checkpoint to DynamoDB.", e) + } + } + + /** + * Start the checkpointer thread with the given checkpoint duration. + */ + private def startCheckpointerThread(): RecurringTimer = { + val period = checkpointInterval.milliseconds + val threadName = s"Kinesis Checkpointer - Worker $workerId" + val timer = new RecurringTimer(clock, period, _ => checkpointAll(), threadName) + timer.start() + logDebug(s"Started checkpointer thread: $threadName") + timer + } + + /** + * Shutdown the checkpointer. Should be called on the onStop of the Receiver. + */ + def shutdown(): Unit = { + // the recurring timer checkpoints for us one last time. + checkpointerThread.stop(interruptTimer = false) + checkpointers.clear() + lastCheckpointedSeqNums.clear() + logInfo("Successfully shutdown Kinesis Checkpointer.") + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 134d627cdaffa..97dbb918573a3 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import scala.util.control.NonFatal import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain} -import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorFactory} +import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessorCheckpointer, IRecordProcessor, IRecordProcessorFactory} import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} import com.amazonaws.services.kinesis.model.Record @@ -31,8 +31,7 @@ import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.Duration import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkEnv} - +import org.apache.spark.Logging private[kinesis] case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) @@ -127,6 +126,11 @@ private[kinesis] class KinesisReceiver[T]( private val blockIdToSeqNumRanges = new mutable.HashMap[StreamBlockId, SequenceNumberRanges] with mutable.SynchronizedMap[StreamBlockId, SequenceNumberRanges] + /** + * The centralized kinesisCheckpointer that checkpoints based on the given checkpointInterval. + */ + @volatile private var kinesisCheckpointer: KinesisCheckpointer = null + /** * Latest sequence number ranges that have been stored successfully. * This is used for checkpointing through KCL */ @@ -141,6 +145,7 @@ private[kinesis] class KinesisReceiver[T]( workerId = Utils.localHostName() + ":" + UUID.randomUUID() + kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId) // KCL config instance val awsCredProvider = resolveAWSCredentialsProvider() val kinesisClientLibConfiguration = @@ -157,8 +162,8 @@ private[kinesis] class KinesisReceiver[T]( * We're using our custom KinesisRecordProcessor in this case. */ val recordProcessorFactory = new IRecordProcessorFactory { - override def createProcessor: IRecordProcessor = new KinesisRecordProcessor(receiver, - workerId, new KinesisCheckpointState(checkpointInterval)) + override def createProcessor: IRecordProcessor = + new KinesisRecordProcessor(receiver, workerId) } worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) @@ -198,6 +203,10 @@ private[kinesis] class KinesisReceiver[T]( logInfo(s"Stopped receiver for workerId $workerId") } workerId = null + if (kinesisCheckpointer != null) { + kinesisCheckpointer.shutdown() + kinesisCheckpointer = null + } } /** Add records of the given shard to the current block being generated */ @@ -207,7 +216,6 @@ private[kinesis] class KinesisReceiver[T]( val metadata = SequenceNumberRange(streamName, shardId, records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber()) blockGenerator.addMultipleDataWithCallback(dataIterator, metadata) - } } @@ -216,6 +224,25 @@ private[kinesis] class KinesisReceiver[T]( shardIdToLatestStoredSeqNum.get(shardId) } + /** + * Set the checkpointer that will be used to checkpoint sequence numbers to DynamoDB for the + * given shardId. + */ + def setCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + assert(kinesisCheckpointer != null, "Kinesis Checkpointer not initialized!") + kinesisCheckpointer.setCheckpointer(shardId, checkpointer) + } + + /** + * Remove the checkpointer for the given shardId. The provided checkpointer will be used to + * checkpoint one last time for the given shard. If `checkpointer` is `null`, then we will not + * checkpoint. + */ + def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + assert(kinesisCheckpointer != null, "Kinesis Checkpointer not initialized!") + kinesisCheckpointer.removeCheckpointer(shardId, checkpointer) + } + /** * Remember the range of sequence numbers that was added to the currently active block. * Internally, this is synchronized with `finalizeRangesForCurrentBlock()`. diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index 1d5178790ec4c..b5b76cb92d866 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -27,26 +27,23 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record import org.apache.spark.Logging +import org.apache.spark.streaming.Duration /** * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. * This implementation operates on the Array[Byte] from the KinesisReceiver. * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each - * shard in the Kinesis stream upon startup. This is normally done in separate threads, - * but the KCLs within the KinesisReceivers will balance themselves out if you create - * multiple Receivers. + * shard in the Kinesis stream upon startup. This is normally done in separate threads, + * but the KCLs within the KinesisReceivers will balance themselves out if you create + * multiple Receivers. * * @param receiver Kinesis receiver * @param workerId for logging purposes - * @param checkpointState represents the checkpoint state including the next checkpoint time. - * It's injected here for mocking purposes. */ -private[kinesis] class KinesisRecordProcessor[T]( - receiver: KinesisReceiver[T], - workerId: String, - checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging { +private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], workerId: String) + extends IRecordProcessor with Logging { - // shardId to be populated during initialize() + // shardId populated during initialize() @volatile private var shardId: String = _ @@ -74,34 +71,7 @@ private[kinesis] class KinesisRecordProcessor[T]( try { receiver.addRecords(shardId, batch) logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") - - /* - * - * Checkpoint the sequence number of the last record successfully stored. - * Note that in this current implementation, the checkpointing occurs only when after - * checkpointIntervalMillis from the last checkpoint, AND when there is new record - * to process. This leads to the checkpointing lagging behind what records have been - * stored by the receiver. Ofcourse, this can lead records processed more than once, - * under failures and restarts. - * - * TODO: Instead of checkpointing here, run a separate timer task to perform - * checkpointing so that it checkpoints in a timely manner independent of whether - * new records are available or not. - */ - if (checkpointState.shouldCheckpoint()) { - receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum => - /* Perform the checkpoint */ - KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 100) - - /* Update the next checkpoint time */ - checkpointState.advanceCheckpoint() - - logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" + - s" records for shardId $shardId") - logDebug(s"Checkpoint: Next checkpoint is at " + - s" ${checkpointState.checkpointClock.getTimeMillis()} for shardId $shardId") - } - } + receiver.setCheckpointer(shardId, checkpointer) } catch { case NonFatal(e) => { /* @@ -110,7 +80,7 @@ private[kinesis] class KinesisRecordProcessor[T]( * more than once. */ logError(s"Exception: WorkerId $workerId encountered and exception while storing " + - " or checkpointing a batch for workerId $workerId and shardId $shardId.", e) + s" or checkpointing a batch for workerId $workerId and shardId $shardId.", e) /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor. */ throw e @@ -142,23 +112,18 @@ private[kinesis] class KinesisRecordProcessor[T]( * It's now OK to read from the new shards that resulted from a resharding event. */ case ShutdownReason.TERMINATE => - val latestSeqNumToCheckpointOption = receiver.getLatestSeqNumToCheckpoint(shardId) - if (latestSeqNumToCheckpointOption.nonEmpty) { - KinesisRecordProcessor.retryRandom( - checkpointer.checkpoint(latestSeqNumToCheckpointOption.get), 4, 100) - } + receiver.removeCheckpointer(shardId, checkpointer) /* - * ZOMBIE Use Case. NoOp. + * ZOMBIE Use Case or Unknown reason. NoOp. * No checkpoint because other workers may have taken over and already started processing * the same records. * This may lead to records being processed more than once. */ - case ShutdownReason.ZOMBIE => - - /* Unknown reason. NoOp */ case _ => + receiver.removeCheckpointer(shardId, null) // return null so that we don't checkpoint } + } } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index 9f9e146a08d46..52c61dfb1c023 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -22,7 +22,8 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.{SparkConf, SparkContext, SparkException} -class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll { +abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) + extends KinesisFunSuite with BeforeAndAfterAll { private val testData = 1 to 8 @@ -37,13 +38,12 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll private var sc: SparkContext = null private var blockManager: BlockManager = null - override def beforeAll(): Unit = { runIfTestsEnabled("Prepare KinesisTestUtils") { testUtils = new KinesisTestUtils() testUtils.createStream() - shardIdToDataAndSeqNumbers = testUtils.pushData(testData) + shardIdToDataAndSeqNumbers = testUtils.pushData(testData, aggregate = aggregateTestData) require(shardIdToDataAndSeqNumbers.size > 1, "Need data to be sent to multiple shards") shardIds = shardIdToDataAndSeqNumbers.keySet.toSeq @@ -247,3 +247,9 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll Array.tabulate(num) { i => new StreamBlockId(0, i) } } } + +class WithAggregationKinesisBackedBlockRDDSuite + extends KinesisBackedBlockRDDTests(aggregateTestData = true) + +class WithoutAggregationKinesisBackedBlockRDDSuite + extends KinesisBackedBlockRDDTests(aggregateTestData = false) diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala new file mode 100644 index 0000000000000..645e64a0bc3a0 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import java.util.concurrent.{TimeoutException, ExecutorService} + +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ +import scala.language.postfixOps + +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach} +import org.scalatest.concurrent.Eventually +import org.scalatest.concurrent.Eventually._ +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.streaming.{Duration, TestSuiteBase} +import org.apache.spark.util.ManualClock + +class KinesisCheckpointerSuite extends TestSuiteBase + with MockitoSugar + with BeforeAndAfterEach + with PrivateMethodTester + with Eventually { + + private val workerId = "dummyWorkerId" + private val shardId = "dummyShardId" + private val seqNum = "123" + private val otherSeqNum = "245" + private val checkpointInterval = Duration(10) + private val someSeqNum = Some(seqNum) + private val someOtherSeqNum = Some(otherSeqNum) + + private var receiverMock: KinesisReceiver[Array[Byte]] = _ + private var checkpointerMock: IRecordProcessorCheckpointer = _ + private var kinesisCheckpointer: KinesisCheckpointer = _ + private var clock: ManualClock = _ + + private val checkpoint = PrivateMethod[Unit]('checkpoint) + + override def beforeEach(): Unit = { + receiverMock = mock[KinesisReceiver[Array[Byte]]] + checkpointerMock = mock[IRecordProcessorCheckpointer] + clock = new ManualClock() + kinesisCheckpointer = new KinesisCheckpointer(receiverMock, checkpointInterval, workerId, clock) + } + + test("checkpoint is not called twice for the same sequence number") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + + verify(checkpointerMock, times(1)).checkpoint(anyString()) + } + + test("checkpoint is called after sequence number increases") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)) + .thenReturn(someSeqNum).thenReturn(someOtherSeqNum) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + + verify(checkpointerMock, times(1)).checkpoint(seqNum) + verify(checkpointerMock, times(1)).checkpoint(otherSeqNum) + } + + test("should checkpoint if we have exceeded the checkpoint interval") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)) + .thenReturn(someSeqNum).thenReturn(someOtherSeqNum) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + clock.advance(5 * checkpointInterval.milliseconds) + + eventually(timeout(1 second)) { + verify(checkpointerMock, times(1)).checkpoint(seqNum) + verify(checkpointerMock, times(1)).checkpoint(otherSeqNum) + } + } + + test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + clock.advance(checkpointInterval.milliseconds / 2) + + verify(checkpointerMock, never()).checkpoint(anyString()) + } + + test("should not checkpoint for the same sequence number") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + + clock.advance(checkpointInterval.milliseconds * 5) + eventually(timeout(1 second)) { + verify(checkpointerMock, atMost(1)).checkpoint(anyString()) + } + } + + test("removing checkpointer checkpoints one last time") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock) + verify(checkpointerMock, times(1)).checkpoint(anyString()) + } + + test("if checkpointing is going on, wait until finished before removing and checkpointing") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)) + .thenReturn(someSeqNum).thenReturn(someOtherSeqNum) + when(checkpointerMock.checkpoint(anyString)).thenAnswer(new Answer[Unit] { + override def answer(invocations: InvocationOnMock): Unit = { + clock.waitTillTime(clock.getTimeMillis() + checkpointInterval.milliseconds / 2) + } + }) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + clock.advance(checkpointInterval.milliseconds) + eventually(timeout(1 second)) { + verify(checkpointerMock, times(1)).checkpoint(anyString()) + } + // don't block test thread + val f = Future(kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock))( + ExecutionContext.global) + + intercept[TimeoutException] { + Await.ready(f, 50 millis) + } + + clock.advance(checkpointInterval.milliseconds / 2) + eventually(timeout(1 second)) { + verify(checkpointerMock, times(2)).checkpoint(anyString()) + } + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 17ab444704f44..e5c70db554a27 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -25,12 +25,13 @@ import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorC import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record import org.mockito.Matchers._ +import org.mockito.Matchers.{eq => meq} import org.mockito.Mockito._ import org.scalatest.mock.MockitoSugar import org.scalatest.{BeforeAndAfter, Matchers} -import org.apache.spark.streaming.{Milliseconds, TestSuiteBase} -import org.apache.spark.util.{Clock, ManualClock, Utils} +import org.apache.spark.streaming.{Duration, TestSuiteBase} +import org.apache.spark.util.Utils /** * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor @@ -44,6 +45,7 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft val workerId = "dummyWorkerId" val shardId = "dummyShardId" val seqNum = "dummySeqNum" + val checkpointInterval = Duration(10) val someSeqNum = Some(seqNum) val record1 = new Record() @@ -54,24 +56,10 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft var receiverMock: KinesisReceiver[Array[Byte]] = _ var checkpointerMock: IRecordProcessorCheckpointer = _ - var checkpointClockMock: ManualClock = _ - var checkpointStateMock: KinesisCheckpointState = _ - var currentClockMock: Clock = _ override def beforeFunction(): Unit = { receiverMock = mock[KinesisReceiver[Array[Byte]]] checkpointerMock = mock[IRecordProcessorCheckpointer] - checkpointClockMock = mock[ManualClock] - checkpointStateMock = mock[KinesisCheckpointState] - currentClockMock = mock[Clock] - } - - override def afterFunction(): Unit = { - super.afterFunction() - // Since this suite was originally written using EasyMock, add this to preserve the old - // mocking semantics (see SPARK-5735 for more details) - verifyNoMoreInteractions(receiverMock, checkpointerMock, checkpointClockMock, - checkpointStateMock, currentClockMock) } test("check serializability of SerializableAWSCredentials") { @@ -79,113 +67,67 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft Utils.serialize(new SerializableAWSCredentials("x", "y"))) } - test("process records including store and checkpoint") { + test("process records including store and set checkpointer") { when(receiverMock.isStopped()).thenReturn(false) - when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) - when(checkpointStateMock.shouldCheckpoint()).thenReturn(true) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.initialize(shardId) recordProcessor.processRecords(batch, checkpointerMock) verify(receiverMock, times(1)).isStopped() verify(receiverMock, times(1)).addRecords(shardId, batch) - verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId) - verify(checkpointStateMock, times(1)).shouldCheckpoint() - verify(checkpointerMock, times(1)).checkpoint(anyString) - verify(checkpointStateMock, times(1)).advanceCheckpoint() + verify(receiverMock, times(1)).setCheckpointer(shardId, checkpointerMock) } - test("shouldn't store and checkpoint when receiver is stopped") { + test("shouldn't store and update checkpointer when receiver is stopped") { when(receiverMock.isStopped()).thenReturn(true) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.processRecords(batch, checkpointerMock) verify(receiverMock, times(1)).isStopped() verify(receiverMock, never).addRecords(anyString, anyListOf(classOf[Record])) - verify(checkpointerMock, never).checkpoint(anyString) + verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock)) } - test("shouldn't checkpoint when exception occurs during store") { + test("shouldn't update checkpointer when exception occurs during store") { when(receiverMock.isStopped()).thenReturn(false) when( receiverMock.addRecords(anyString, anyListOf(classOf[Record])) ).thenThrow(new RuntimeException()) intercept[RuntimeException] { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.initialize(shardId) recordProcessor.processRecords(batch, checkpointerMock) } verify(receiverMock, times(1)).isStopped() verify(receiverMock, times(1)).addRecords(shardId, batch) - verify(checkpointerMock, never).checkpoint(anyString) - } - - test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointIntervalMillis = 10 - val checkpointState = - new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) - - verify(currentClockMock, times(1)).getTimeMillis() - } - - test("should checkpoint if we have exceeded the checkpoint interval") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) - assert(checkpointState.shouldCheckpoint()) - - verify(currentClockMock, times(1)).getTimeMillis() - } - - test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) - assert(!checkpointState.shouldCheckpoint()) - - verify(currentClockMock, times(1)).getTimeMillis() - } - - test("should add to time when advancing checkpoint") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointIntervalMillis = 10 - val checkpointState = - new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) - checkpointState.advanceCheckpoint() - assert(checkpointState.checkpointClock.getTimeMillis() == (2 * checkpointIntervalMillis)) - - verify(currentClockMock, times(1)).getTimeMillis() + verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock)) } test("shutdown should checkpoint if the reason is TERMINATE") { when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.initialize(shardId) recordProcessor.shutdown(checkpointerMock, ShutdownReason.TERMINATE) - verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId) - verify(checkpointerMock, times(1)).checkpoint(anyString) + verify(receiverMock, times(1)).removeCheckpointer(meq(shardId), meq(checkpointerMock)) } + test("shutdown should not checkpoint if the reason is something other than TERMINATE") { when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.initialize(shardId) recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) recordProcessor.shutdown(checkpointerMock, null) - verify(checkpointerMock, never).checkpoint(anyString) + verify(receiverMock, times(2)).removeCheckpointer(meq(shardId), + meq[IRecordProcessorCheckpointer](null)) } test("retry success on first attempt") { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index ba84e557dfcc2..dee30444d8cc6 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -39,7 +39,7 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo import org.apache.spark.util.Utils import org.apache.spark.{SparkConf, SparkContext} -class KinesisStreamSuite extends KinesisFunSuite +abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFunSuite with Eventually with BeforeAndAfter with BeforeAndAfterAll { // This is the name that KCL will use to save metadata to DynamoDB @@ -182,13 +182,13 @@ class KinesisStreamSuite extends KinesisFunSuite val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => collected ++= rdd.collect() - logInfo("Collected = " + rdd.collect().toSeq.mkString(", ")) + logInfo("Collected = " + collected.mkString(", ")) } ssc.start() val testData = 1 to 10 eventually(timeout(120 seconds), interval(10 second)) { - testUtils.pushData(testData) + testUtils.pushData(testData, aggregateTestData) assert(collected === testData.toSet, "\nData received does not match data sent") } ssc.stop(stopSparkContext = false) @@ -207,13 +207,13 @@ class KinesisStreamSuite extends KinesisFunSuite val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] stream.foreachRDD { rdd => collected ++= rdd.collect() - logInfo("Collected = " + rdd.collect().toSeq.mkString(", ")) + logInfo("Collected = " + collected.mkString(", ")) } ssc.start() val testData = 1 to 10 eventually(timeout(120 seconds), interval(10 second)) { - testUtils.pushData(testData) + testUtils.pushData(testData, aggregateTestData) val modData = testData.map(_ + 5) assert(collected === modData.toSet, "\nData received does not match data sent") } @@ -254,7 +254,7 @@ class KinesisStreamSuite extends KinesisFunSuite // If this times out because numBatchesWithData is empty, then its likely that foreachRDD // function failed with exceptions, and nothing got added to `collectedData` eventually(timeout(2 minutes), interval(1 seconds)) { - testUtils.pushData(1 to 5) + testUtils.pushData(1 to 5, aggregateTestData) assert(isCheckpointPresent && numBatchesWithData > 10) } ssc.stop(stopSparkContext = true) // stop the SparkContext so that the blocks are not reused @@ -285,5 +285,8 @@ class KinesisStreamSuite extends KinesisFunSuite } ssc.stop() } - } + +class WithAggregationKinesisStreamSuite extends KinesisStreamTests(aggregateTestData = true) + +class WithoutAggregationKinesisStreamSuite extends KinesisStreamTests(aggregateTestData = false) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala similarity index 80% rename from extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala rename to extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 634bf94521079..7487aa1c12639 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -31,6 +31,8 @@ import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient import com.amazonaws.services.dynamodbv2.document.DynamoDB import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.model._ +import com.amazonaws.services.kinesis.producer.{KinesisProducer, KinesisProducerConfiguration, UserRecordResult} +import com.google.common.util.concurrent.{FutureCallback, Futures} import org.apache.spark.Logging @@ -64,6 +66,16 @@ private[kinesis] class KinesisTestUtils extends Logging { new DynamoDB(dynamoDBClient) } + private lazy val kinesisProducer: KinesisProducer = { + val conf = new KinesisProducerConfiguration() + .setRecordMaxBufferedTime(1000) + .setMaxConnections(1) + .setRegion(regionName) + .setMetricsLevel("none") + + new KinesisProducer(conf) + } + def streamName: String = { require(streamCreated, "Stream not yet created, call createStream() to create one") _streamName @@ -90,22 +102,41 @@ private[kinesis] class KinesisTestUtils extends Logging { * Push data to Kinesis stream and return a map of * shardId -> seq of (data, seq number) pushed to corresponding shard */ - def pushData(testData: Seq[Int]): Map[String, Seq[(Int, String)]] = { + def pushData(testData: Seq[Int], aggregate: Boolean): Map[String, Seq[(Int, String)]] = { require(streamCreated, "Stream not yet created, call createStream() to create one") val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() testData.foreach { num => val str = num.toString - val putRecordRequest = new PutRecordRequest().withStreamName(streamName) - .withData(ByteBuffer.wrap(str.getBytes())) - .withPartitionKey(str) - - val putRecordResult = kinesisClient.putRecord(putRecordRequest) - val shardId = putRecordResult.getShardId - val seqNumber = putRecordResult.getSequenceNumber() - val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, - new ArrayBuffer[(Int, String)]()) - sentSeqNumbers += ((num, seqNumber)) + val data = ByteBuffer.wrap(str.getBytes()) + if (aggregate) { + val future = kinesisProducer.addUserRecord(streamName, str, data) + val kinesisCallBack = new FutureCallback[UserRecordResult]() { + override def onFailure(t: Throwable): Unit = {} // do nothing + + override def onSuccess(result: UserRecordResult): Unit = { + val shardId = result.getShardId + val seqNumber = result.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + } + + Futures.addCallback(future, kinesisCallBack) + kinesisProducer.flushSync() // make sure we send all data before returning the map + } else { + val putRecordRequest = new PutRecordRequest().withStreamName(streamName) + .withData(data) + .withPartitionKey(str) + + val putRecordResult = kinesisClient.putRecord(putRecordRequest) + val shardId = putRecordResult.getShardId + val seqNumber = putRecordResult.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } } logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}") @@ -116,7 +147,7 @@ private[kinesis] class KinesisTestUtils extends Logging { * Expose a Python friendly API. */ def pushData(testData: java.util.List[Int]): Unit = { - pushData(testData.asScala) + pushData(testData.asScala, aggregate = false) } def deleteStream(): Unit = { diff --git a/graphx/pom.xml b/graphx/pom.xml index 987b831021a54..8cd66c5b2e826 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -47,6 +47,10 @@ test-jar test + + org.apache.xbean + xbean-asm5-shaded + com.google.guava guava diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 9451ff1e5c0e2..9827dfab8684a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -282,7 +282,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * Convert bi-directional edges into uni-directional ones. * Some graph algorithms (e.g., TriangleCount) assume that an input graph * has its edges in canonical direction. - * This function rewrites the vertex ids of edges so that srcIds are bigger + * This function rewrites the vertex ids of edges so that srcIds are smaller * than dstIds, and merges the duplicated edges. * * @param mergeFunc the user defined reduce function which should diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index 74a7de18d4161..a6d0cb6409664 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -22,11 +22,10 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.HashSet import scala.language.existentials -import org.apache.spark.util.Utils - -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor} -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ +import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor} +import org.apache.xbean.asm5.Opcodes._ +import org.apache.spark.util.Utils /** * Includes an utility function to test whether a function accesses a specific attribute @@ -107,18 +106,19 @@ private[graphx] object BytecodeUtils { * MethodInvocationFinder("spark/graph/Foo", "test") * its methodsInvoked variable will contain the set of methods invoked directly by * Foo.test(). Interface invocations are not returned as part of the result set because we cannot - * determine the actual metod invoked by inspecting the bytecode. + * determine the actual method invoked by inspecting the bytecode. */ private class MethodInvocationFinder(className: String, methodName: String) - extends ClassVisitor(ASM4) { + extends ClassVisitor(ASM5) { val methodsInvoked = new HashSet[(Class[_], String)] override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { if (name == methodName) { - new MethodVisitor(ASM4) { - override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + new MethodVisitor(ASM5) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { if (op == INVOKEVIRTUAL || op == INVOKESPECIAL || op == INVOKESTATIC) { if (!skipClass(owner)) { methodsInvoked.add((Utils.classForName(owner.replace("/", ".")), name)) diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java index de50f14fbdc87..1bfda289dec39 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java @@ -18,6 +18,7 @@ package org.apache.spark.launcher; import java.io.IOException; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ThreadFactory; @@ -102,8 +103,20 @@ public synchronized void kill() { disconnect(); } if (childProc != null) { - childProc.destroy(); - childProc = null; + try { + childProc.exitValue(); + } catch (IllegalThreadStateException e) { + // Child is still alive. Try to use Java 8's "destroyForcibly()" if available, + // fall back to the old API if it's not there. + try { + Method destroy = childProc.getClass().getMethod("destroyForcibly"); + destroy.invoke(childProc); + } catch (Exception inner) { + childProc.destroy(); + } + } finally { + childProc = null; + } } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java index 13dd9f1739fb6..e9caf0b3cb063 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java @@ -89,6 +89,9 @@ public boolean isFinal() { * Tries to kill the underlying application. Implies {@link #disconnect()}. This will not send * a {@link #stop()} message to the application, so it's recommended that users first try to * stop the application cleanly and only resort to this method if that fails. + *

    + * Note that if the application is running as a child process, this method fail to kill the + * process when using Java 7. This may happen if, for example, the application is deadlocked. */ void kill(); diff --git a/make-distribution.sh b/make-distribution.sh index e1c2afdbc6d87..d7d27e253f721 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -220,6 +220,7 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR" if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then mkdir -p "$DISTDIR"/R/lib cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib + cp "$SPARK_HOME/R/lib/sparkr.zip" "$DISTDIR"/R/lib fi # Download and copy in tachyon, if requested diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index f5fca686df144..a88f52674102c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -21,13 +21,14 @@ import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} import org.apache.spark.annotation.Experimental import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics @@ -396,7 +397,7 @@ class LogisticRegressionModel private[ml] ( val coefficients: Vector, val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] - with LogisticRegressionParams { + with LogisticRegressionParams with Writable { @deprecated("Use coefficients instead.", "1.6.0") def weights: Vector = coefficients @@ -510,8 +511,71 @@ class LogisticRegressionModel private[ml] ( // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden. if (probability(1) > getThreshold) 1 else 0 } + + /** + * Returns a [[Writer]] instance for this ML instance. + * + * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. + */ + override def write: Writer = new LogisticRegressionWriter(this) +} + + +/** [[Writer]] instance for [[LogisticRegressionModel]] */ +private[classification] class LogisticRegressionWriter(instance: LogisticRegressionModel) + extends Writer with Logging { + + private case class Data( + numClasses: Int, + numFeatures: Int, + intercept: Double, + coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: numClasses, numFeatures, intercept, coefficients + val data = Data(instance.numClasses, instance.numFeatures, instance.intercept, + instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath) + } +} + + +object LogisticRegressionModel extends Readable[LogisticRegressionModel] { + + override def read: Reader[LogisticRegressionModel] = new LogisticRegressionReader + + override def load(path: String): LogisticRegressionModel = read.load(path) } + +private[classification] class LogisticRegressionReader extends Reader[LogisticRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = "org.apache.spark.ml.classification.LogisticRegressionModel" + + override def load(path: String): LogisticRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.format("parquet").load(dataPath) + .select("numClasses", "numFeatures", "intercept", "coefficients").head() + // We will need numClasses, numFeatures in the future for multinomial logreg support. + // val numClasses = data.getInt(0) + // val numFeatures = data.getInt(1) + val intercept = data.getDouble(2) + val coefficients = data.getAs[Vector](3) + val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } +} + + /** * MultiClassSummarizer computes the number of distinct labels and corresponding counts, * and validates the data to see if the labels used for k class multi-label classification diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala new file mode 100644 index 0000000000000..92e05815d6a3d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -0,0 +1,709 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter} +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, + EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, + LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, + OnlineLDAOptimizer => OldOnlineLDAOptimizer} +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors, Matrix, Vector} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf} +import org.apache.spark.sql.types.StructType + + +private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter + with HasSeed with HasCheckpointInterval { + + /** + * Param for the number of topics (clusters) to infer. Must be > 1. Default: 10. + * @group param + */ + @Since("1.6.0") + final val k = new IntParam(this, "k", "number of topics (clusters) to infer", + ParamValidators.gt(1)) + + /** @group getParam */ + @Since("1.6.0") + def getK: Int = $(k) + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This is the parameter to a Dirichlet distribution, where larger values mean more smoothing + * (more regularization). + * + * If not set by the user, then docConcentration is set automatically. If set to + * singleton vector [alpha], then alpha is replicated to a vector of length k in fitting. + * Otherwise, the [[docConcentration]] vector must be length k. + * (default = automatic) + * + * Optimizer-specific parameter settings: + * - EM + * - Currently only supports symmetric distributions, so all values in the vector should be + * the same. + * - Values should be > 1.0 + * - default = uniformly (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows + * from Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * - Online + * - Values should be >= 0 + * - default = uniformly (1.0 / k), following the implementation from + * [[https://github.com/Blei-Lab/onlineldavb]]. + * @group param + */ + @Since("1.6.0") + final val docConcentration = new DoubleArrayParam(this, "docConcentration", + "Concentration parameter (commonly named \"alpha\") for the prior placed on documents'" + + " distributions over topics (\"theta\").", (alpha: Array[Double]) => alpha.forall(_ >= 0.0)) + + /** @group getParam */ + @Since("1.6.0") + def getDocConcentration: Array[Double] = $(docConcentration) + + /** Get docConcentration used by spark.mllib LDA */ + protected def getOldDocConcentration: Vector = { + if (isSet(docConcentration)) { + Vectors.dense(getDocConcentration) + } else { + Vectors.dense(-1.0) + } + } + + /** + * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' + * distributions over terms. + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. + * + * If not set by the user, then topicConcentration is set automatically. + * (default = automatic) + * + * Optimizer-specific parameter settings: + * - EM + * - Value should be > 1.0 + * - default = 0.1 + 1, where 0.1 gives a small amount of smoothing and +1 follows + * Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * - Online + * - Value should be >= 0 + * - default = (1.0 / k), following the implementation from + * [[https://github.com/Blei-Lab/onlineldavb]]. + * @group param + */ + @Since("1.6.0") + final val topicConcentration = new DoubleParam(this, "topicConcentration", + "Concentration parameter (commonly named \"beta\" or \"eta\") for the prior placed on topic'" + + " distributions over terms.", ParamValidators.gtEq(0)) + + /** @group getParam */ + @Since("1.6.0") + def getTopicConcentration: Double = $(topicConcentration) + + /** Get topicConcentration used by spark.mllib LDA */ + protected def getOldTopicConcentration: Double = { + if (isSet(topicConcentration)) { + getTopicConcentration + } else { + -1.0 + } + } + + /** Supported values for Param [[optimizer]]. */ + @Since("1.6.0") + final val supportedOptimizers: Array[String] = Array("online", "em") + + /** + * Optimizer or inference algorithm used to estimate the LDA model. + * Currently supported (case-insensitive): + * - "online": Online Variational Bayes (default) + * - "em": Expectation-Maximization + * + * For details, see the following papers: + * - Online LDA: + * Hoffman, Blei and Bach. "Online Learning for Latent Dirichlet Allocation." + * Neural Information Processing Systems, 2010. + * [[http://www.cs.columbia.edu/~blei/papers/HoffmanBleiBach2010b.pdf]] + * - EM: + * Asuncion et al. "On Smoothing and Inference for Topic Models." + * Uncertainty in Artificial Intelligence, 2009. + * [[http://arxiv.org/pdf/1205.2662.pdf]] + * + * @group param + */ + @Since("1.6.0") + final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" + + " algorithm used to estimate the LDA model. Supported: " + supportedOptimizers.mkString(", "), + (o: String) => ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase)) + + /** @group getParam */ + @Since("1.6.0") + def getOptimizer: String = $(optimizer) + + /** + * Output column with estimates of the topic mixture distribution for each document (often called + * "theta" in the literature). Returns a vector of zeros for an empty document. + * + * This uses a variational approximation following Hoffman et al. (2010), where the approximate + * distribution is called "gamma." Technically, this method returns this approximation "gamma" + * for each document. + * @group param + */ + @Since("1.6.0") + final val topicDistributionCol = new Param[String](this, "topicDistribution", "Output column" + + " with estimates of the topic mixture distribution for each document (often called \"theta\"" + + " in the literature). Returns a vector of zeros for an empty document.") + + setDefault(topicDistributionCol -> "topicDistribution") + + /** @group getParam */ + @Since("1.6.0") + def getTopicDistributionCol: String = $(topicDistributionCol) + + /** + * A (positive) learning parameter that downweights early iterations. Larger values make early + * iterations count less. + * This is called "tau0" in the Online LDA paper (Hoffman et al., 2010) + * Default: 1024, following Hoffman et al. + * @group expertParam + */ + @Since("1.6.0") + final val learningOffset = new DoubleParam(this, "learningOffset", "A (positive) learning" + + " parameter that downweights early iterations. Larger values make early iterations count less.", + ParamValidators.gt(0)) + + /** @group expertGetParam */ + @Since("1.6.0") + def getLearningOffset: Double = $(learningOffset) + + /** + * Learning rate, set as an exponential decay rate. + * This should be between (0.5, 1.0] to guarantee asymptotic convergence. + * This is called "kappa" in the Online LDA paper (Hoffman et al., 2010). + * Default: 0.51, based on Hoffman et al. + * @group expertParam + */ + @Since("1.6.0") + final val learningDecay = new DoubleParam(this, "learningDecay", "Learning rate, set as an" + + " exponential decay rate. This should be between (0.5, 1.0] to guarantee asymptotic" + + " convergence.", ParamValidators.gt(0)) + + /** @group expertGetParam */ + @Since("1.6.0") + def getLearningDecay: Double = $(learningDecay) + + /** + * Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, + * in range (0, 1]. + * + * Note that this should be adjusted in synch with [[LDA.maxIter]] + * so the entire corpus is used. Specifically, set both so that + * maxIterations * miniBatchFraction >= 1. + * + * Note: This is the same as the `miniBatchFraction` parameter in + * [[org.apache.spark.mllib.clustering.OnlineLDAOptimizer]]. + * + * Default: 0.05, i.e., 5% of total documents. + * @group param + */ + @Since("1.6.0") + final val subsamplingRate = new DoubleParam(this, "subsamplingRate", "Fraction of the corpus" + + " to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1].", + ParamValidators.inRange(0.0, 1.0, lowerInclusive = false, upperInclusive = true)) + + /** @group getParam */ + @Since("1.6.0") + def getSubsamplingRate: Double = $(subsamplingRate) + + /** + * Indicates whether the docConcentration (Dirichlet parameter for + * document-topic distribution) will be optimized during training. + * Setting this to true will make the model more expressive and fit the training data better. + * Default: false + * @group expertParam + */ + @Since("1.6.0") + final val optimizeDocConcentration = new BooleanParam(this, "optimizeDocConcentration", + "Indicates whether the docConcentration (Dirichlet parameter for document-topic" + + " distribution) will be optimized during training.") + + /** @group expertGetParam */ + @Since("1.6.0") + def getOptimizeDocConcentration: Boolean = $(optimizeDocConcentration) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) + } + + @Since("1.6.0") + override def validateParams(): Unit = { + if (isSet(docConcentration)) { + if (getDocConcentration.length != 1) { + require(getDocConcentration.length == getK, s"LDA docConcentration was of length" + + s" ${getDocConcentration.length}, but k = $getK. docConcentration must be an array of" + + s" length either 1 (scalar) or k (num topics).") + } + getOptimizer match { + case "online" => + require(getDocConcentration.forall(_ >= 0), + "For Online LDA optimizer, docConcentration values must be >= 0. Found values: " + + getDocConcentration.mkString(",")) + case "em" => + require(getDocConcentration.forall(_ >= 0), + "For EM optimizer, docConcentration values must be >= 1. Found values: " + + getDocConcentration.mkString(",")) + } + } + if (isSet(topicConcentration)) { + getOptimizer match { + case "online" => + require(getTopicConcentration >= 0, s"For Online LDA optimizer, topicConcentration" + + s" must be >= 0. Found value: $getTopicConcentration") + case "em" => + require(getTopicConcentration >= 0, s"For EM optimizer, topicConcentration" + + s" must be >= 1. Found value: $getTopicConcentration") + } + } + } + + private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match { + case "online" => + new OldOnlineLDAOptimizer() + .setTau0($(learningOffset)) + .setKappa($(learningDecay)) + .setMiniBatchFraction($(subsamplingRate)) + .setOptimizeDocConcentration($(optimizeDocConcentration)) + case "em" => + new OldEMLDAOptimizer() + } +} + + +/** + * :: Experimental :: + * Model fitted by [[LDA]]. + * + * @param vocabSize Vocabulary size (number of terms or terms in the vocabulary) + * @param sqlContext Used to construct local DataFrames for returning query results + */ +@Since("1.6.0") +@Experimental +sealed abstract class LDAModel private[ml] ( + @Since("1.6.0") override val uid: String, + @Since("1.6.0") val vocabSize: Int, + @Since("1.6.0") @transient protected val sqlContext: SQLContext) + extends Model[LDAModel] with LDAParams with Logging { + + // NOTE to developers: + // This abstraction should contain all important functionality for basic LDA usage. + // Specializations of this class can contain expert-only functionality. + + /** + * Underlying spark.mllib model. + * If this model was produced by Online LDA, then this is the only model representation. + * If this model was produced by EM, then this local representation may be built lazily. + */ + @Since("1.6.0") + protected def oldLocalModel: OldLocalLDAModel + + /** Returns underlying spark.mllib model, which may be local or distributed */ + @Since("1.6.0") + protected def getModel: OldLDAModel + + /** + * The features for LDA should be a [[Vector]] representing the word counts in a document. + * The vector should be of length vocabSize, with counts for each term (word). + * @group setParam + */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setSeed(value: Long): this.type = set(seed, value) + + /** + * Transforms the input dataset. + * + * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]] + * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver. + * This implementation may be changed in the future. + */ + @Since("1.6.0") + override def transform(dataset: DataFrame): DataFrame = { + if ($(topicDistributionCol).nonEmpty) { + val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext)) + dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))) + } else { + logWarning("LDAModel.transform was called without any output columns. Set an output column" + + " such as topicDistributionCol to produce results.") + dataset + } + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + /** + * Value for [[docConcentration]] estimated from data. + * If Online LDA was used and [[optimizeDocConcentration]] was set to false, + * then this returns the fixed (given) value for the [[docConcentration]] parameter. + */ + @Since("1.6.0") + def estimatedDocConcentration: Vector = getModel.docConcentration + + /** + * Inferred topics, where each topic is represented by a distribution over terms. + * This is a matrix of size vocabSize x k, where each column is a topic. + * No guarantees are given about the ordering of the topics. + * + * WARNING: If this model is actually a [[DistributedLDAModel]] instance produced by + * the Expectation-Maximization ("em") [[optimizer]], then this method could involve + * collecting a large amount of data to the driver (on the order of vocabSize x k). + */ + @Since("1.6.0") + def topicsMatrix: Matrix = oldLocalModel.topicsMatrix + + /** Indicates whether this instance is of type [[DistributedLDAModel]] */ + @Since("1.6.0") + def isDistributed: Boolean + + /** + * Calculates a lower bound on the log likelihood of the entire corpus. + * + * See Equation (16) in the Online LDA paper (Hoffman et al., 2010). + * + * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]] + * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver. + * This implementation may be changed in the future. + * + * @param dataset test corpus to use for calculating log likelihood + * @return variational lower bound on the log likelihood of the entire corpus + */ + @Since("1.6.0") + def logLikelihood(dataset: DataFrame): Double = { + val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) + oldLocalModel.logLikelihood(oldDataset) + } + + /** + * Calculate an upper bound bound on perplexity. (Lower is better.) + * See Equation (16) in the Online LDA paper (Hoffman et al., 2010). + * + * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]] + * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver. + * This implementation may be changed in the future. + * + * @param dataset test corpus to use for calculating perplexity + * @return Variational upper bound on log perplexity per token. + */ + @Since("1.6.0") + def logPerplexity(dataset: DataFrame): Double = { + val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) + oldLocalModel.logPerplexity(oldDataset) + } + + /** + * Return the topics described by their top-weighted terms. + * + * @param maxTermsPerTopic Maximum number of terms to collect for each topic. + * Default value of 10. + * @return Local DataFrame with one topic per Row, with columns: + * - "topic": IntegerType: topic index + * - "termIndices": ArrayType(IntegerType): term indices, sorted in order of decreasing + * term importance + * - "termWeights": ArrayType(DoubleType): corresponding sorted term weights + */ + @Since("1.6.0") + def describeTopics(maxTermsPerTopic: Int): DataFrame = { + val topics = getModel.describeTopics(maxTermsPerTopic).zipWithIndex.map { + case ((termIndices, termWeights), topic) => + (topic, termIndices.toSeq, termWeights.toSeq) + } + sqlContext.createDataFrame(topics).toDF("topic", "termIndices", "termWeights") + } + + @Since("1.6.0") + def describeTopics(): DataFrame = describeTopics(10) +} + + +/** + * :: Experimental :: + * + * Local (non-distributed) model fitted by [[LDA]]. + * + * This model stores the inferred topics only; it does not store info about the training dataset. + */ +@Since("1.6.0") +@Experimental +class LocalLDAModel private[ml] ( + uid: String, + vocabSize: Int, + @Since("1.6.0") override protected val oldLocalModel: OldLocalLDAModel, + sqlContext: SQLContext) + extends LDAModel(uid, vocabSize, sqlContext) { + + @Since("1.6.0") + override def copy(extra: ParamMap): LocalLDAModel = { + val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext) + copyValues(copied, extra).setParent(parent).asInstanceOf[LocalLDAModel] + } + + override protected def getModel: OldLDAModel = oldLocalModel + + @Since("1.6.0") + override def isDistributed: Boolean = false +} + + +/** + * :: Experimental :: + * + * Distributed model fitted by [[LDA]]. + * This type of model is currently only produced by Expectation-Maximization (EM). + * + * This model stores the inferred topics, the full training dataset, and the topic distribution + * for each training document. + * + * @param oldLocalModelOption Used to implement [[oldLocalModel]] as a lazy val, but keeping + * [[copy()]] cheap. + */ +@Since("1.6.0") +@Experimental +class DistributedLDAModel private[ml] ( + uid: String, + vocabSize: Int, + private val oldDistributedModel: OldDistributedLDAModel, + sqlContext: SQLContext, + private var oldLocalModelOption: Option[OldLocalLDAModel]) + extends LDAModel(uid, vocabSize, sqlContext) { + + override protected def oldLocalModel: OldLocalLDAModel = { + if (oldLocalModelOption.isEmpty) { + oldLocalModelOption = Some(oldDistributedModel.toLocal) + } + oldLocalModelOption.get + } + + override protected def getModel: OldLDAModel = oldDistributedModel + + /** + * Convert this distributed model to a local representation. This discards info about the + * training dataset. + * + * WARNING: This involves collecting a large [[topicsMatrix]] to the driver. + */ + @Since("1.6.0") + def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext) + + @Since("1.6.0") + override def copy(extra: ParamMap): DistributedLDAModel = { + val copied = + new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext, oldLocalModelOption) + copyValues(copied, extra).setParent(parent) + copied + } + + @Since("1.6.0") + override def isDistributed: Boolean = true + + /** + * Log likelihood of the observed tokens in the training set, + * given the current parameter estimates: + * log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters) + * + * Notes: + * - This excludes the prior; for that, use [[logPrior]]. + * - Even with [[logPrior]], this is NOT the same as the data log likelihood given the + * hyperparameters. + * - This is computed from the topic distributions computed during training. If you call + * [[logLikelihood()]] on the same training dataset, the topic distributions will be computed + * again, possibly giving different results. + */ + @Since("1.6.0") + lazy val trainingLogLikelihood: Double = oldDistributedModel.logLikelihood + + /** + * Log probability of the current parameter estimate: + * log P(topics, topic distributions for docs | Dirichlet hyperparameters) + */ + @Since("1.6.0") + lazy val logPrior: Double = oldDistributedModel.logPrior +} + + +/** + * :: Experimental :: + * + * Latent Dirichlet Allocation (LDA), a topic model designed for text documents. + * + * Terminology: + * - "term" = "word": an element of the vocabulary + * - "token": instance of a term appearing in a document + * - "topic": multinomial distribution over terms representing some concept + * - "document": one piece of text, corresponding to one row in the input data + * + * References: + * - Original LDA paper (journal version): + * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + * + * Input data (featuresCol): + * LDA is given a collection of documents as input data, via the featuresCol parameter. + * Each document is specified as a [[Vector]] of length vocabSize, where each entry is the + * count for the corresponding term (word) in the document. Feature transformers such as + * [[org.apache.spark.ml.feature.Tokenizer]] and [[org.apache.spark.ml.feature.CountVectorizer]] + * can be useful for converting text to word count vectors. + * + * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation + * (Wikipedia)]] + */ +@Since("1.6.0") +@Experimental +class LDA @Since("1.6.0") ( + @Since("1.6.0") override val uid: String) extends Estimator[LDAModel] with LDAParams { + + @Since("1.6.0") + def this() = this(Identifiable.randomUID("lda")) + + setDefault(maxIter -> 20, k -> 10, optimizer -> "online", checkpointInterval -> 10, + learningOffset -> 1024, learningDecay -> 0.51, subsamplingRate -> 0.05, + optimizeDocConcentration -> true) + + /** + * The features for LDA should be a [[Vector]] representing the word counts in a document. + * The vector should be of length vocabSize, with counts for each term (word). + * @group setParam + */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("1.6.0") + def setSeed(value: Long): this.type = set(seed, value) + + /** @group setParam */ + @Since("1.6.0") + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + + /** @group setParam */ + @Since("1.6.0") + def setK(value: Int): this.type = set(k, value) + + /** @group setParam */ + @Since("1.6.0") + def setDocConcentration(value: Array[Double]): this.type = set(docConcentration, value) + + /** @group setParam */ + @Since("1.6.0") + def setDocConcentration(value: Double): this.type = set(docConcentration, Array(value)) + + /** @group setParam */ + @Since("1.6.0") + def setTopicConcentration(value: Double): this.type = set(topicConcentration, value) + + /** @group setParam */ + @Since("1.6.0") + def setOptimizer(value: String): this.type = set(optimizer, value) + + /** @group setParam */ + @Since("1.6.0") + def setTopicDistributionCol(value: String): this.type = set(topicDistributionCol, value) + + /** @group expertSetParam */ + @Since("1.6.0") + def setLearningOffset(value: Double): this.type = set(learningOffset, value) + + /** @group expertSetParam */ + @Since("1.6.0") + def setLearningDecay(value: Double): this.type = set(learningDecay, value) + + /** @group setParam */ + @Since("1.6.0") + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + + /** @group expertSetParam */ + @Since("1.6.0") + def setOptimizeDocConcentration(value: Boolean): this.type = set(optimizeDocConcentration, value) + + @Since("1.6.0") + override def copy(extra: ParamMap): LDA = defaultCopy(extra) + + @Since("1.6.0") + override def fit(dataset: DataFrame): LDAModel = { + transformSchema(dataset.schema, logging = true) + val oldLDA = new OldLDA() + .setK($(k)) + .setDocConcentration(getOldDocConcentration) + .setTopicConcentration(getOldTopicConcentration) + .setMaxIterations($(maxIter)) + .setSeed($(seed)) + .setCheckpointInterval($(checkpointInterval)) + .setOptimizer(getOldOptimizer) + // TODO: persist here, or in old LDA? + val oldData = LDA.getOldDataset(dataset, $(featuresCol)) + val oldModel = oldLDA.run(oldData) + val newModel = oldModel match { + case m: OldLocalLDAModel => + new LocalLDAModel(uid, m.vocabSize, m, dataset.sqlContext) + case m: OldDistributedLDAModel => + new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None) + } + copyValues(newModel).setParent(this) + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } +} + + +private[clustering] object LDA { + + /** Get dataset for spark.mllib LDA */ + def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = { + dataset + .withColumn("docId", monotonicallyIncreasingId()) + .select("docId", featuresCol) + .map { case Row(docId: Long, features: Vector) => + (docId, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 248288ca73e99..1b82b40caac18 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -100,10 +100,25 @@ class RegexTokenizer(override val uid: String) /** @group getParam */ def getPattern: String = $(pattern) - setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+") + /** + * Indicates whether to convert all characters to lowercase before tokenizing. + * Default: true + * @group param + */ + final val toLowercase: BooleanParam = new BooleanParam(this, "toLowercase", + "whether to convert all characters to lowercase before tokenizing.") + + /** @group setParam */ + def setToLowercase(value: Boolean): this.type = set(toLowercase, value) + + /** @group getParam */ + def getToLowercase: Boolean = $(toLowercase) + + setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true) - override protected def createTransformFunc: String => Seq[String] = { str => + override protected def createTransformFunc: String => Seq[String] = { originStr => val re = $(pattern).r + val str = if ($(toLowercase)) originStr.toLowerCase() else originStr val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq val minLength = $(minTokenLength) tokens.filter(_.length >= minLength) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 9edab3af913ca..708dbeef84db4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -17,18 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors} -import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types._ /** @@ -148,10 +146,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] @Experimental class Word2VecModel private[ml] ( override val uid: String, - wordVectors: feature.Word2VecModel) + @transient private val wordVectors: feature.Word2VecModel) extends Model[Word2VecModel] with Word2VecBase { - /** * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and * and the vector the DenseVector that it is mapped to. @@ -197,22 +194,23 @@ class Word2VecModel private[ml] ( */ override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) - val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors) + val vectors = wordVectors.getVectors + .mapValues(vv => Vectors.dense(vv.map(_.toDouble))) + .map(identity) // mapValues doesn't return a serializable map (SI-7005) + val bVectors = dataset.sqlContext.sparkContext.broadcast(vectors) + val d = $(vectorSize) val word2Vec = udf { sentence: Seq[String] => if (sentence.size == 0) { - Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double]) + Vectors.sparse(d, Array.empty[Int], Array.empty[Double]) } else { - val cum = Vectors.zeros($(vectorSize)) - val model = bWordVectors.value.getVectors - for (word <- sentence) { - if (model.contains(word)) { - axpy(1.0, bWordVectors.value.transform(word), cum) - } else { - // pass words which not belong to model + val sum = Vectors.zeros(d) + sentence.foreach { word => + bVectors.value.get(word).foreach { v => + BLAS.axpy(1.0, v, sum) } } - scal(1.0 / sentence.size, cum) - cum + BLAS.scal(1.0 / sentence.size, sum) + sum } } dataset.withColumn($(outputCol), word2Vec(col($(inputCol)))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 5be2f86936211..4d82b90bfdf20 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -52,11 +52,36 @@ private[r] object SparkRWrappers { } def getModelCoefficients(model: PipelineModel): Array[Double] = { + model.stages.last match { + case m: LinearRegressionModel => { + val coefficientStandardErrorsR = Array(m.summary.coefficientStandardErrors.last) ++ + m.summary.coefficientStandardErrors.dropRight(1) + val tValuesR = Array(m.summary.tValues.last) ++ m.summary.tValues.dropRight(1) + val pValuesR = Array(m.summary.pValues.last) ++ m.summary.pValues.dropRight(1) + if (m.getFitIntercept) { + Array(m.intercept) ++ m.coefficients.toArray ++ coefficientStandardErrorsR ++ + tValuesR ++ pValuesR + } else { + m.coefficients.toArray ++ coefficientStandardErrorsR ++ tValuesR ++ pValuesR + } + } + case m: LogisticRegressionModel => { + if (m.getFitIntercept) { + Array(m.intercept) ++ m.coefficients.toArray + } else { + m.coefficients.toArray + } + } + } + } + + def getModelDevianceResiduals(model: PipelineModel): Array[Double] = { model.stages.last match { case m: LinearRegressionModel => - Array(m.intercept) ++ m.coefficients.toArray + m.summary.devianceResiduals case m: LogisticRegressionModel => - Array(m.intercept) ++ m.coefficients.toArray + throw new UnsupportedOperationException( + "No deviance residuals available for LogisticRegressionModel") } } @@ -65,11 +90,28 @@ private[r] object SparkRWrappers { case m: LinearRegressionModel => val attrs = AttributeGroup.fromStructField( m.summary.predictions.schema(m.summary.featuresCol)) - Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + if (m.getFitIntercept) { + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + } else { + attrs.attributes.get.map(_.name.get) + } case m: LogisticRegressionModel => val attrs = AttributeGroup.fromStructField( m.summary.predictions.schema(m.summary.featuresCol)) - Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + if (m.getFitIntercept) { + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + } else { + attrs.attributes.get.map(_.name.get) + } + } + } + + def getModelName(model: PipelineModel): String = { + model.stages.last match { + case m: LinearRegressionModel => + "LinearRegressionModel" + case m: LogisticRegressionModel => + "LogisticRegressionModel" } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 1f627777fc68d..11b9815ecc832 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -82,7 +82,7 @@ private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val * .load("data/mllib/sample_libsvm_data.txt") * * // Java - * DataFrame df = sqlContext.read.format("libsvm") + * DataFrame df = sqlContext.read().format("libsvm") * .option("numFeatures, "780") * .load("data/mllib/sample_libsvm_data.txt"); * }}} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index ea790e0dddc7f..ca896ed6106c4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -48,9 +48,15 @@ private[util] sealed trait BaseReadWrite { /** * Returns the user-specified SQL context or the default. */ - protected final def sqlContext: SQLContext = optionSQLContext.getOrElse { - SQLContext.getOrCreate(SparkContext.getOrCreate()) + protected final def sqlContext: SQLContext = { + if (optionSQLContext.isEmpty) { + optionSQLContext = Some(SQLContext.getOrCreate(SparkContext.getOrCreate())) + } + optionSQLContext.get } + + /** Returns the [[SparkContext]] underlying [[sqlContext]] */ + protected final def sc: SparkContext = sqlContext.sparkContext } /** @@ -58,7 +64,7 @@ private[util] sealed trait BaseReadWrite { */ @Experimental @Since("1.6.0") -abstract class Writer extends BaseReadWrite { +abstract class Writer extends BaseReadWrite with Logging { protected var shouldOverwrite: Boolean = false @@ -67,7 +73,29 @@ abstract class Writer extends BaseReadWrite { */ @Since("1.6.0") @throws[IOException]("If the input path already exists but overwrite is not enabled.") - def save(path: String): Unit + def save(path: String): Unit = { + val hadoopConf = sc.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + val p = new Path(path) + if (fs.exists(p)) { + if (shouldOverwrite) { + logInfo(s"Path $path already exists. It will be overwritten.") + // TODO: Revert back to the original content if save is not successful. + fs.delete(p, true) + } else { + throw new IOException( + s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") + } + } + saveImpl(path) + } + + /** + * [[save()]] handles overwriting and then calls this method. Subclasses should override this + * method to implement the actual saving of the instance. + */ + @Since("1.6.0") + protected def saveImpl(path: String): Unit /** * Overwrites if the output path already exists. @@ -147,28 +175,24 @@ trait Readable[T] { * data (e.g., models with coefficients). * @param instance object to save */ -private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging { +private[ml] class DefaultParamsWriter(instance: Params) extends Writer { - /** - * Saves the ML component to the input path. - */ - override def save(path: String): Unit = { - val sc = sqlContext.sparkContext + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + } +} - val hadoopConf = sc.hadoopConfiguration - val fs = FileSystem.get(hadoopConf) - val p = new Path(path) - if (fs.exists(p)) { - if (shouldOverwrite) { - logInfo(s"Path $path already exists. It will be overwritten.") - // TODO: Revert back to the original content if save is not successful. - fs.delete(p, true) - } else { - throw new IOException( - s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") - } - } +private[ml] object DefaultParamsWriter { + /** + * Saves metadata + Params to: path + "/metadata" + * - class + * - timestamp + * - sparkVersion + * - uid + * - paramMap + */ + def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] @@ -177,6 +201,7 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg }.toList val metadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ + ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ ("paramMap" -> jsonParams) val metadataPath = new Path(path, "metadata").toString @@ -193,19 +218,62 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg */ private[ml] class DefaultParamsReader[T] extends Reader[T] { + override def load(path: String): T = { + val metadata = DefaultParamsReader.loadMetadata(path, sc) + val cls = Utils.classForName(metadata.className) + val instance = + cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params] + DefaultParamsReader.getAndSetParams(instance, metadata) + instance.asInstanceOf[T] + } +} + +private[ml] object DefaultParamsReader { + /** - * Loads the ML component from the input path. + * All info from metadata file. + * @param params paramMap, as a [[JValue]] + * @param metadataStr Full metadata file String (for debugging) */ - override def load(path: String): T = { - implicit val format = DefaultFormats - val sc = sqlContext.sparkContext + case class Metadata( + className: String, + uid: String, + timestamp: Long, + sparkVersion: String, + params: JValue, + metadataStr: String) + + /** + * Load metadata from file. + * @param expectedClassName If non empty, this is checked against the loaded metadata. + * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata + */ + def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = { val metadataPath = new Path(path, "metadata").toString val metadataStr = sc.textFile(metadataPath, 1).first() val metadata = parse(metadataStr) - val cls = Utils.classForName((metadata \ "class").extract[String]) + + implicit val format = DefaultFormats + val className = (metadata \ "class").extract[String] val uid = (metadata \ "uid").extract[String] - val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params] - (metadata \ "paramMap") match { + val timestamp = (metadata \ "timestamp").extract[Long] + val sparkVersion = (metadata \ "sparkVersion").extract[String] + val params = metadata \ "paramMap" + if (expectedClassName.nonEmpty) { + require(className == expectedClassName, s"Error loading metadata: Expected class name" + + s" $expectedClassName but found class name $className") + } + + Metadata(className, uid, timestamp, sparkVersion, params, metadataStr) + } + + /** + * Extract Params from metadata, and set them in the instance. + * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]]. + */ + def getAndSetParams(instance: Params, metadata: Metadata): Unit = { + implicit val format = DefaultFormats + metadata.params match { case JObject(pairs) => pairs.foreach { case (paramName, jsonValue) => val param = instance.getParam(paramName) @@ -213,8 +281,8 @@ private[ml] class DefaultParamsReader[T] extends Reader[T] { instance.set(param, value) } case _ => - throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.") + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: ${metadata.metadataStr}.") } - instance.asInstanceOf[T] } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala index 0ec88ef77d695..6a3b20c88d2d2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala @@ -17,14 +17,11 @@ package org.apache.spark.mllib.api.python -import java.util.{List => JList} - -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters import org.apache.spark.SparkContext -import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix} import org.apache.spark.mllib.clustering.GaussianMixtureModel +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * Wrapper around GaussianMixtureModel to provide helper methods in Python @@ -36,17 +33,11 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { /** * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian */ - val gaussians: JList[Object] = { - val modelGaussians = model.gaussians - var i = 0 - var mu = ArrayBuffer.empty[Vector] - var sigma = ArrayBuffer.empty[Matrix] - while (i < k) { - mu += modelGaussians(i).mu - sigma += modelGaussians(i).sigma - i += 1 + val gaussians: Array[Byte] = { + val modelGaussians = model.gaussians.map { gaussian => + Array[Any](gaussian.mu, gaussian.sigma) } - List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + SerDe.dumps(JavaConverters.seqAsJavaListConverter(modelGaussians).asJava) } def save(sc: SparkContext, path: String): Unit = model.save(sc, path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala new file mode 100644 index 0000000000000..63282eee6e656 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.api.python + +import scala.collection.JavaConverters + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.clustering.LDAModel +import org.apache.spark.mllib.linalg.Matrix + +/** + * Wrapper around LDAModel to provide helper methods in Python + */ +private[python] class LDAModelWrapper(model: LDAModel) { + + def topicsMatrix(): Matrix = model.topicsMatrix + + def vocabSize(): Int = model.vocabSize + + def describeTopics(): Array[Byte] = describeTopics(this.model.vocabSize) + + def describeTopics(maxTermsPerTopic: Int): Array[Byte] = { + val topics = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) => + val jTerms = JavaConverters.seqAsJavaListConverter(terms).asJava + val jTermWeights = JavaConverters.seqAsJavaListConverter(termWeights).asJava + Array[Any](jTerms, jTermWeights) + } + SerDe.dumps(JavaConverters.seqAsJavaListConverter(topics).asJava) + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 40c41806cdfea..54b03a9f90283 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -517,7 +517,7 @@ private[python] class PythonMLLibAPI extends Serializable { topicConcentration: Double, seed: java.lang.Long, checkpointInterval: Int, - optimizer: String): LDAModel = { + optimizer: String): LDAModelWrapper = { val algo = new LDA() .setK(k) .setMaxIterations(maxIterations) @@ -535,7 +535,16 @@ private[python] class PythonMLLibAPI extends Serializable { case _ => throw new IllegalArgumentException("input values contains invalid type value.") } } - algo.run(documents) + val model = algo.run(documents) + new LDAModelWrapper(model) + } + + /** + * Load a LDA model + */ + def loadLDAModel(jsc: JavaSparkContext, path: String): LDAModelWrapper = { + val model = DistributedLDAModel.load(jsc.sc, path) + new LDAModelWrapper(model) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala new file mode 100644 index 0000000000000..29a7aa0bb63f2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -0,0 +1,491 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +/** + * A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" + * by Steinbach, Karypis, and Kumar, with modification to fit Spark. + * The algorithm starts from a single cluster that contains all points. + * Iteratively it finds divisible clusters on the bottom level and bisects each of them using + * k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible. + * The bisecting steps of clusters on the same level are grouped together to increase parallelism. + * If bisecting all divisible clusters on the bottom level would result more than `k` leaf clusters, + * larger clusters get higher priority. + * + * @param k the desired number of leaf clusters (default: 4). The actual number could be smaller if + * there are no divisible leaf clusters. + * @param maxIterations the max number of k-means iterations to split clusters (default: 20) + * @param minDivisibleClusterSize the minimum number of points (if >= 1.0) or the minimum proportion + * of points (if < 1.0) of a divisible cluster (default: 1) + * @param seed a random seed (default: hash value of the class name) + * + * @see [[http://glaros.dtc.umn.edu/gkhome/fetch/papers/docclusterKDDTMW00.pdf + * Steinbach, Karypis, and Kumar, A comparison of document clustering techniques, + * KDD Workshop on Text Mining, 2000.]] + */ +@Since("1.6.0") +@Experimental +class BisectingKMeans private ( + private var k: Int, + private var maxIterations: Int, + private var minDivisibleClusterSize: Double, + private var seed: Long) extends Logging { + + import BisectingKMeans._ + + /** + * Constructs with the default configuration + */ + @Since("1.6.0") + def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##) + + /** + * Sets the desired number of leaf clusters (default: 4). + * The actual number could be smaller if there are no divisible leaf clusters. + */ + @Since("1.6.0") + def setK(k: Int): this.type = { + require(k > 0, s"k must be positive but got $k.") + this.k = k + this + } + + /** + * Gets the desired number of leaf clusters. + */ + @Since("1.6.0") + def getK: Int = this.k + + /** + * Sets the max number of k-means iterations to split clusters (default: 20). + */ + @Since("1.6.0") + def setMaxIterations(maxIterations: Int): this.type = { + require(maxIterations > 0, s"maxIterations must be positive but got $maxIterations.") + this.maxIterations = maxIterations + this + } + + /** + * Gets the max number of k-means iterations to split clusters. + */ + @Since("1.6.0") + def getMaxIterations: Int = this.maxIterations + + /** + * Sets the minimum number of points (if >= `1.0`) or the minimum proportion of points + * (if < `1.0`) of a divisible cluster (default: 1). + */ + @Since("1.6.0") + def setMinDivisibleClusterSize(minDivisibleClusterSize: Double): this.type = { + require(minDivisibleClusterSize > 0.0, + s"minDivisibleClusterSize must be positive but got $minDivisibleClusterSize.") + this.minDivisibleClusterSize = minDivisibleClusterSize + this + } + + /** + * Gets the minimum number of points (if >= `1.0`) or the minimum proportion of points + * (if < `1.0`) of a divisible cluster. + */ + @Since("1.6.0") + def getMinDivisibleClusterSize: Double = minDivisibleClusterSize + + /** + * Sets the random seed (default: hash value of the class name). + */ + @Since("1.6.0") + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + + /** + * Gets the random seed. + */ + @Since("1.6.0") + def getSeed: Long = this.seed + + /** + * Runs the bisecting k-means algorithm. + * @param input RDD of vectors + * @return model for the bisecting kmeans + */ + @Since("1.6.0") + def run(input: RDD[Vector]): BisectingKMeansModel = { + if (input.getStorageLevel == StorageLevel.NONE) { + logWarning(s"The input RDD ${input.id} is not directly cached, which may hurt performance if" + + " its parent RDDs are also not cached.") + } + val d = input.map(_.size).first() + logInfo(s"Feature dimension: $d.") + // Compute and cache vector norms for fast distance computation. + val norms = input.map(v => Vectors.norm(v, 2.0)).persist(StorageLevel.MEMORY_AND_DISK) + val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) } + var assignments = vectors.map(v => (ROOT_INDEX, v)) + var activeClusters = summarize(d, assignments) + val rootSummary = activeClusters(ROOT_INDEX) + val n = rootSummary.size + logInfo(s"Number of points: $n.") + logInfo(s"Initial cost: ${rootSummary.cost}.") + val minSize = if (minDivisibleClusterSize >= 1.0) { + math.ceil(minDivisibleClusterSize).toLong + } else { + math.ceil(minDivisibleClusterSize * n).toLong + } + logInfo(s"The minimum number of points of a divisible cluster is $minSize.") + var inactiveClusters = mutable.Seq.empty[(Long, ClusterSummary)] + val random = new Random(seed) + var numLeafClustersNeeded = k - 1 + var level = 1 + while (activeClusters.nonEmpty && numLeafClustersNeeded > 0 && level < LEVEL_LIMIT) { + // Divisible clusters are sufficiently large and have non-trivial cost. + var divisibleClusters = activeClusters.filter { case (_, summary) => + (summary.size >= minSize) && (summary.cost > MLUtils.EPSILON * summary.size) + } + // If we don't need all divisible clusters, take the larger ones. + if (divisibleClusters.size > numLeafClustersNeeded) { + divisibleClusters = divisibleClusters.toSeq.sortBy { case (_, summary) => + -summary.size + }.take(numLeafClustersNeeded) + .toMap + } + if (divisibleClusters.nonEmpty) { + val divisibleIndices = divisibleClusters.keys.toSet + logInfo(s"Dividing ${divisibleIndices.size} clusters on level $level.") + var newClusterCenters = divisibleClusters.flatMap { case (index, summary) => + val (left, right) = splitCenter(summary.center, random) + Iterator((leftChildIndex(index), left), (rightChildIndex(index), right)) + }.map(identity) // workaround for a Scala bug (SI-7005) that produces a not serializable map + var newClusters: Map[Long, ClusterSummary] = null + var newAssignments: RDD[(Long, VectorWithNorm)] = null + for (iter <- 0 until maxIterations) { + newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters) + .filter { case (index, _) => + divisibleIndices.contains(parentIndex(index)) + } + newClusters = summarize(d, newAssignments) + newClusterCenters = newClusters.mapValues(_.center).map(identity) + } + // TODO: Unpersist old indices. + val indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys + .persist(StorageLevel.MEMORY_AND_DISK) + assignments = indices.zip(vectors) + inactiveClusters ++= activeClusters + activeClusters = newClusters + numLeafClustersNeeded -= divisibleClusters.size + } else { + logInfo(s"None active and divisible clusters left on level $level. Stop iterations.") + inactiveClusters ++= activeClusters + activeClusters = Map.empty + } + level += 1 + } + val clusters = activeClusters ++ inactiveClusters + val root = buildTree(clusters) + new BisectingKMeansModel(root) + } + + /** + * Java-friendly version of [[run(RDD[Vector])*]] + */ + def run(data: JavaRDD[Vector]): BisectingKMeansModel = run(data.rdd) +} + +private object BisectingKMeans extends Serializable { + + /** The index of the root node of a tree. */ + private val ROOT_INDEX: Long = 1 + + private val MAX_DIVISIBLE_CLUSTER_INDEX: Long = Long.MaxValue / 2 + + private val LEVEL_LIMIT = math.log10(Long.MaxValue) / math.log10(2) + + /** Returns the left child index of the given node index. */ + private def leftChildIndex(index: Long): Long = { + require(index <= MAX_DIVISIBLE_CLUSTER_INDEX, s"Child index out of bound: 2 * $index.") + 2 * index + } + + /** Returns the right child index of the given node index. */ + private def rightChildIndex(index: Long): Long = { + require(index <= MAX_DIVISIBLE_CLUSTER_INDEX, s"Child index out of bound: 2 * $index + 1.") + 2 * index + 1 + } + + /** Returns the parent index of the given node index, or 0 if the input is 1 (root). */ + private def parentIndex(index: Long): Long = { + index / 2 + } + + /** + * Summarizes data by each cluster as Map. + * @param d feature dimension + * @param assignments pairs of point and its cluster index + * @return a map from cluster indices to corresponding cluster summaries + */ + private def summarize( + d: Int, + assignments: RDD[(Long, VectorWithNorm)]): Map[Long, ClusterSummary] = { + assignments.aggregateByKey(new ClusterSummaryAggregator(d))( + seqOp = (agg, v) => agg.add(v), + combOp = (agg1, agg2) => agg1.merge(agg2) + ).mapValues(_.summary) + .collect().toMap + } + + /** + * Cluster summary aggregator. + * @param d feature dimension + */ + private class ClusterSummaryAggregator(val d: Int) extends Serializable { + private var n: Long = 0L + private val sum: Vector = Vectors.zeros(d) + private var sumSq: Double = 0.0 + + /** Adds a point. */ + def add(v: VectorWithNorm): this.type = { + n += 1L + // TODO: use a numerically stable approach to estimate cost + sumSq += v.norm * v.norm + BLAS.axpy(1.0, v.vector, sum) + this + } + + /** Merges another aggregator. */ + def merge(other: ClusterSummaryAggregator): this.type = { + n += other.n + sumSq += other.sumSq + BLAS.axpy(1.0, other.sum, sum) + this + } + + /** Returns the summary. */ + def summary: ClusterSummary = { + val mean = sum.copy + if (n > 0L) { + BLAS.scal(1.0 / n, mean) + } + val center = new VectorWithNorm(mean) + val cost = math.max(sumSq - n * center.norm * center.norm, 0.0) + new ClusterSummary(n, center, cost) + } + } + + /** + * Bisects a cluster center. + * + * @param center current cluster center + * @param random a random number generator + * @return initial centers + */ + private def splitCenter( + center: VectorWithNorm, + random: Random): (VectorWithNorm, VectorWithNorm) = { + val d = center.vector.size + val norm = center.norm + val level = 1e-4 * norm + val noise = Vectors.dense(Array.fill(d)(random.nextDouble())) + val left = center.vector.copy + BLAS.axpy(-level, noise, left) + val right = center.vector.copy + BLAS.axpy(level, noise, right) + (new VectorWithNorm(left), new VectorWithNorm(right)) + } + + /** + * Updates assignments. + * @param assignments current assignments + * @param divisibleIndices divisible cluster indices + * @param newClusterCenters new cluster centers + * @return new assignments + */ + private def updateAssignments( + assignments: RDD[(Long, VectorWithNorm)], + divisibleIndices: Set[Long], + newClusterCenters: Map[Long, VectorWithNorm]): RDD[(Long, VectorWithNorm)] = { + assignments.map { case (index, v) => + if (divisibleIndices.contains(index)) { + val children = Seq(leftChildIndex(index), rightChildIndex(index)) + val selected = children.minBy { child => + KMeans.fastSquaredDistance(newClusterCenters(child), v) + } + (selected, v) + } else { + (index, v) + } + } + } + + /** + * Builds a clustering tree by re-indexing internal and leaf clusters. + * @param clusters a map from cluster indices to corresponding cluster summaries + * @return the root node of the clustering tree + */ + private def buildTree(clusters: Map[Long, ClusterSummary]): ClusteringTreeNode = { + var leafIndex = 0 + var internalIndex = -1 + + /** + * Builds a subtree from this given node index. + */ + def buildSubTree(rawIndex: Long): ClusteringTreeNode = { + val cluster = clusters(rawIndex) + val size = cluster.size + val center = cluster.center + val cost = cluster.cost + val isInternal = clusters.contains(leftChildIndex(rawIndex)) + if (isInternal) { + val index = internalIndex + internalIndex -= 1 + val leftIndex = leftChildIndex(rawIndex) + val rightIndex = rightChildIndex(rawIndex) + val height = math.sqrt(Seq(leftIndex, rightIndex).map { childIndex => + KMeans.fastSquaredDistance(center, clusters(childIndex).center) + }.max) + val left = buildSubTree(leftIndex) + val right = buildSubTree(rightIndex) + new ClusteringTreeNode(index, size, center, cost, height, Array(left, right)) + } else { + val index = leafIndex + leafIndex += 1 + val height = 0.0 + new ClusteringTreeNode(index, size, center, cost, height, Array.empty) + } + } + + buildSubTree(ROOT_INDEX) + } + + /** + * Summary of a cluster. + * + * @param size the number of points within this cluster + * @param center the center of the points within this cluster + * @param cost the sum of squared distances to the center + */ + private case class ClusterSummary(size: Long, center: VectorWithNorm, cost: Double) +} + +/** + * Represents a node in a clustering tree. + * + * @param index node index, negative for internal nodes and non-negative for leaf nodes + * @param size size of the cluster + * @param centerWithNorm cluster center with norm + * @param cost cost of the cluster, i.e., the sum of squared distances to the center + * @param height height of the node in the dendrogram. Currently this is defined as the max distance + * from the center to the centers of the children's, but subject to change. + * @param children children nodes + */ +@Since("1.6.0") +@Experimental +class ClusteringTreeNode private[clustering] ( + val index: Int, + val size: Long, + private val centerWithNorm: VectorWithNorm, + val cost: Double, + val height: Double, + val children: Array[ClusteringTreeNode]) extends Serializable { + + /** Whether this is a leaf node. */ + val isLeaf: Boolean = children.isEmpty + + require((isLeaf && index >= 0) || (!isLeaf && index < 0)) + + /** Cluster center. */ + def center: Vector = centerWithNorm.vector + + /** Predicts the leaf cluster node index that the input point belongs to. */ + def predict(point: Vector): Int = { + val (index, _) = predict(new VectorWithNorm(point)) + index + } + + /** Returns the full prediction path from root to leaf. */ + def predictPath(point: Vector): Array[ClusteringTreeNode] = { + predictPath(new VectorWithNorm(point)).toArray + } + + /** Returns the full prediction path from root to leaf. */ + private def predictPath(pointWithNorm: VectorWithNorm): List[ClusteringTreeNode] = { + if (isLeaf) { + this :: Nil + } else { + val selected = children.minBy { child => + KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm) + } + selected :: selected.predictPath(pointWithNorm) + } + } + + /** + * Computes the cost (squared distance to the predicted leaf cluster center) of the input point. + */ + def computeCost(point: Vector): Double = { + val (_, cost) = predict(new VectorWithNorm(point)) + cost + } + + /** + * Predicts the cluster index and the cost of the input point. + */ + private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = { + predict(pointWithNorm, KMeans.fastSquaredDistance(centerWithNorm, pointWithNorm)) + } + + /** + * Predicts the cluster index and the cost of the input point. + * @param pointWithNorm input point + * @param cost the cost to the current center + * @return (predicted leaf cluster index, cost) + */ + private def predict(pointWithNorm: VectorWithNorm, cost: Double): (Int, Double) = { + if (isLeaf) { + (index, cost) + } else { + val (selectedChild, minCost) = children.map { child => + (child, KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm)) + }.minBy(_._2) + selectedChild.predict(pointWithNorm, minCost) + } + } + + /** + * Returns all leaf nodes from this node. + */ + def leafNodes: Array[ClusteringTreeNode] = { + if (isLeaf) { + Array(this) + } else { + children.flatMap(_.leafNodes) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala new file mode 100644 index 0000000000000..5015f1540d920 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD + +/** + * Clustering model produced by [[BisectingKMeans]]. + * The prediction is done level-by-level from the root node to a leaf node, and at each node among + * its children the closest to the input point is selected. + * + * @param root the root node of the clustering tree + */ +@Since("1.6.0") +@Experimental +class BisectingKMeansModel @Since("1.6.0") ( + @Since("1.6.0") val root: ClusteringTreeNode + ) extends Serializable with Logging { + + /** + * Leaf cluster centers. + */ + @Since("1.6.0") + def clusterCenters: Array[Vector] = root.leafNodes.map(_.center) + + /** + * Number of leaf clusters. + */ + lazy val k: Int = clusterCenters.length + + /** + * Predicts the index of the cluster that the input point belongs to. + */ + @Since("1.6.0") + def predict(point: Vector): Int = { + root.predict(point) + } + + /** + * Predicts the indices of the clusters that the input points belong to. + */ + @Since("1.6.0") + def predict(points: RDD[Vector]): RDD[Int] = { + points.map { p => root.predict(p) } + } + + /** + * Java-friendly version of [[predict(RDD[Vector])*]] + */ + @Since("1.6.0") + def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = + predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] + + /** + * Computes the squared distance between the input point and the cluster center it belongs to. + */ + @Since("1.6.0") + def computeCost(point: Vector): Double = { + root.computeCost(point) + } + + /** + * Computes the sum of squared distances between the input points and their corresponding cluster + * centers. + */ + @Since("1.6.0") + def computeCost(data: RDD[Vector]): Double = { + data.map(root.computeCost).sum() + } + + /** + * Java-friendly version of [[computeCost(RDD[Vector])*]]. + */ + @Since("1.6.0") + def computeCost(data: JavaRDD[Vector]): Double = this.computeCost(data.rdd) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 31d8a9fdea1c6..cd520f09bd466 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -183,8 +183,7 @@ abstract class LDAModel private[clustering] extends Saveable { /** * Local LDA model. * This model stores only the inferred topics. - * It may be used for computing topics for new documents, but it may give less accurate answers - * than the [[DistributedLDAModel]]. + * * @param topics Inferred topics (vocabSize x k matrix). */ @Since("1.3.0") @@ -353,7 +352,7 @@ class LocalLDAModel private[clustering] ( documents.map { case (id: Long, termCounts: Vector) => if (termCounts.numNonzeros == 0) { - (id, Vectors.zeros(k)) + (id, Vectors.zeros(k)) } else { val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference( termCounts, @@ -366,6 +365,28 @@ class LocalLDAModel private[clustering] ( } } + /** Get a method usable as a UDF for [[topicDistributions()]] */ + private[spark] def getTopicDistributionMethod(sc: SparkContext): Vector => Vector = { + val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t) + val expElogbetaBc = sc.broadcast(expElogbeta) + val docConcentrationBrz = this.docConcentration.toBreeze + val gammaShape = this.gammaShape + val k = this.k + + (termCounts: Vector) => + if (termCounts.numNonzeros == 0) { + Vectors.zeros(k) + } else { + val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, + expElogbetaBc.value, + docConcentrationBrz, + gammaShape, + k) + Vectors.dense(normalize(gamma, 1.0).toArray) + } + } + /** * Java-friendly version of [[topicDistributions]] */ @@ -477,8 +498,6 @@ object LocalLDAModel extends Loader[LocalLDAModel] { /** * Distributed LDA model. * This model stores the inferred topics, the full training dataset, and the topic distributions. - * When computing topics for new documents, it may give more accurate answers - * than the [[LocalLDAModel]]. */ @Since("1.3.0") class DistributedLDAModel private[clustering] ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index c5fdecd3ca17f..9267e6dbdb857 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -32,6 +32,7 @@ private[mllib] trait PMMLModelExport { @BeanProperty val pmml: PMML = new PMML + pmml.setVersion("4.2") setHeader(pmml) private def setHeader(pmml: PMML): Unit = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 78172843be56e..19a047ded257c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -37,15 +37,20 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { * trigger a Spark job if the parent RDD has more than one partitions and the window size is * greater than 1. */ - def sliding(windowSize: Int): RDD[Array[T]] = { + def sliding(windowSize: Int, step: Int): RDD[Array[T]] = { require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.") - if (windowSize == 1) { + if (windowSize == 1 && step == 1) { self.map(Array(_)) } else { - new SlidingRDD[T](self, windowSize) + new SlidingRDD[T](self, windowSize, step) } } + /** + * [[sliding(Int, Int)*]] with step = 1. + */ + def sliding(windowSize: Int): RDD[Array[T]] = sliding(windowSize, 1) + /** * Reduces the elements of this RDD in a multi-level tree pattern. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index 1facf83d806d0..ead8db6344998 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -24,13 +24,13 @@ import org.apache.spark.{TaskContext, Partition} import org.apache.spark.rdd.RDD private[mllib] -class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T]) +class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T], val offset: Int) extends Partition with Serializable { override val index: Int = idx } /** - * Represents a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding + * Represents an RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding * window over them. The ordering is first based on the partition index and then the ordering of * items within each partition. This is similar to sliding in Scala collections, except that it * becomes an empty RDD if the window size is greater than the total number of items. It needs to @@ -40,19 +40,24 @@ class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T] * * @param parent the parent RDD * @param windowSize the window size, must be greater than 1 + * @param step step size for windows * - * @see [[org.apache.spark.mllib.rdd.RDDFunctions#sliding]] + * @see [[org.apache.spark.mllib.rdd.RDDFunctions.sliding(Int, Int)*]] + * @see [[scala.collection.IterableLike.sliding(Int, Int)*]] */ private[mllib] -class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int) +class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int, val step: Int) extends RDD[Array[T]](parent) { - require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.") + require(windowSize > 0 && step > 0 && !(windowSize == 1 && step == 1), + "Window size and step must be greater than 0, " + + s"and they cannot be both 1, but got windowSize = $windowSize and step = $step.") override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = { val part = split.asInstanceOf[SlidingRDDPartition[T]] (firstParent[T].iterator(part.prev, context) ++ part.tail) - .sliding(windowSize) + .drop(part.offset) + .sliding(windowSize, step) .withPartial(false) .map(_.toArray) } @@ -62,40 +67,42 @@ class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int override def getPartitions: Array[Partition] = { val parentPartitions = parent.partitions - val n = parentPartitions.size + val n = parentPartitions.length if (n == 0) { Array.empty } else if (n == 1) { - Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty)) + Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty, 0)) } else { - val n1 = n - 1 val w1 = windowSize - 1 - // Get the first w1 items of each partition, starting from the second partition. - val nextHeads = - parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n) - val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]() + // Get partition sizes and first w1 elements. + val (sizes, heads) = parent.mapPartitions { iter => + val w1Array = iter.take(w1).toArray + Iterator.single((w1Array.length + iter.length, w1Array)) + }.collect().unzip + val partitions = mutable.ArrayBuffer.empty[SlidingRDDPartition[T]] var i = 0 + var cumSize = 0 var partitionIndex = 0 - while (i < n1) { - var j = i - val tail = mutable.ListBuffer[T]() - // Keep appending to the current tail until appended a head of size w1. - while (j < n1 && nextHeads(j).size < w1) { - tail ++= nextHeads(j) - j += 1 + while (i < n) { + val mod = cumSize % step + val offset = if (mod == 0) 0 else step - mod + val size = sizes(i) + if (offset < size) { + val tail = mutable.ListBuffer.empty[T] + // Keep appending to the current tail until it has w1 elements. + var j = i + 1 + while (j < n && tail.length < w1) { + tail ++= heads(j).take(w1 - tail.length) + j += 1 + } + if (sizes(i) + tail.length >= offset + windowSize) { + partitions += + new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail, offset) + partitionIndex += 1 + } } - if (j < n1) { - tail ++= nextHeads(j) - j += 1 - } - partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail) - partitionIndex += 1 - // Skip appended heads. - i = j - } - // If the head of last partition has size w1, we also need to add this partition. - if (nextHeads.last.size == w1) { - partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(n1), Seq.empty) + cumSize += size + i += 1 } partitions.toArray } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java index 02309ce63219a..c407d98f1b795 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -53,6 +53,7 @@ public void regexTokenizer() { .setOutputCol("tokens") .setPattern("\\s") .setGaps(true) + .setToLowercase(false) .setMinTokenLength(3); diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java index c39538014be81..01ff1ea658610 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -32,17 +32,23 @@ public class JavaDefaultReadWriteSuite { JavaSparkContext jsc = null; + SQLContext sqlContext = null; File tempDir = null; @Before public void setUp() { jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite"); + SQLContext.clearActive(); + sqlContext = new SQLContext(jsc); + SQLContext.setActive(sqlContext); tempDir = Utils.createTempDir( System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite"); } @After public void tearDown() { + sqlContext = null; + SQLContext.clearActive(); if (jsc != null) { jsc.stop(); jsc = null; @@ -64,7 +70,6 @@ public void testDefaultReadWrite() throws IOException { } catch (IOException e) { // expected } - SQLContext sqlContext = new SQLContext(jsc); instance.write().context(sqlContext).overwrite().save(outputPath); MyParams newInstance = MyParams.load(outputPath); Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java new file mode 100644 index 0000000000000..a714620ff7e4b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +public class JavaBisectingKMeansSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", this.getClass().getSimpleName()); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void twoDimensionalData() { + JavaRDD points = sc.parallelize(Lists.newArrayList( + Vectors.dense(4, -1), + Vectors.dense(4, 1), + Vectors.sparse(2, new int[] {0}, new double[] {1.0}) + ), 2); + + BisectingKMeans bkm = new BisectingKMeans() + .setK(4) + .setMaxIterations(2) + .setSeed(1L); + BisectingKMeansModel model = bkm.run(points); + Assert.assertEquals(3, model.k()); + Assert.assertArrayEquals(new double[] {3.0, 0.0}, model.root().center().toArray(), 1e-12); + for (ClusteringTreeNode child: model.root().children()) { + double[] center = child.center().toArray(); + if (center[0] > 2) { + Assert.assertEquals(2, child.size()); + Assert.assertArrayEquals(new double[] {4.0, 0.0}, center, 1e-12); + } else { + Assert.assertEquals(1, child.size()); + Assert.assertArrayEquals(new double[] {1.0, 0.0}, center, 1e-12); + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 325faf37e8eea..51b06b7eb6d53 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -23,7 +23,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{Identifiable, DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint @@ -31,7 +31,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class LogisticRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var dataset: DataFrame = _ @transient var binaryDataset: DataFrame = _ @@ -869,6 +870,18 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) assert(model1a0.coefficients ~== model1b.coefficients absTol 1E-3) assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) + } + test("read/write") { + // Set some Params to make sure set Params are serialized. + val lr = new LogisticRegression() + .setElasticNetParam(0.1) + .setMaxIter(2) + .fit(dataset) + val lr2 = testDefaultReadWrite(lr) + assert(lr.intercept === lr2.intercept) + assert(lr.coefficients.toArray === lr2.coefficients.toArray) + assert(lr.numClasses === lr2.numClasses) + assert(lr.numFeatures === lr2.numFeatures) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala new file mode 100644 index 0000000000000..b634d31cc34f0 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + + +object LDASuite { + def generateLDAData( + sql: SQLContext, + rows: Int, + k: Int, + vocabSize: Int): DataFrame = { + val avgWC = 1 // average instances of each word in a doc + val sc = sql.sparkContext + val rng = new java.util.Random() + rng.setSeed(1) + val rdd = sc.parallelize(1 to rows).map { i => + Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble)) + }.map(v => new TestRow(v)) + sql.createDataFrame(rdd) + } +} + + +class LDASuite extends SparkFunSuite with MLlibTestSparkContext { + + val k: Int = 5 + val vocabSize: Int = 30 + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + dataset = LDASuite.generateLDAData(sqlContext, 50, k, vocabSize) + } + + test("default parameters") { + val lda = new LDA() + + assert(lda.getFeaturesCol === "features") + assert(lda.getMaxIter === 20) + assert(lda.isDefined(lda.seed)) + assert(lda.getCheckpointInterval === 10) + assert(lda.getK === 10) + assert(!lda.isSet(lda.docConcentration)) + assert(!lda.isSet(lda.topicConcentration)) + assert(lda.getOptimizer === "online") + assert(lda.getLearningDecay === 0.51) + assert(lda.getLearningOffset === 1024) + assert(lda.getSubsamplingRate === 0.05) + assert(lda.getOptimizeDocConcentration) + assert(lda.getTopicDistributionCol === "topicDistribution") + } + + test("set parameters") { + val lda = new LDA() + .setFeaturesCol("test_feature") + .setMaxIter(33) + .setSeed(123) + .setCheckpointInterval(7) + .setK(9) + .setTopicConcentration(0.56) + .setTopicDistributionCol("myOutput") + + assert(lda.getFeaturesCol === "test_feature") + assert(lda.getMaxIter === 33) + assert(lda.getSeed === 123) + assert(lda.getCheckpointInterval === 7) + assert(lda.getK === 9) + assert(lda.getTopicConcentration === 0.56) + assert(lda.getTopicDistributionCol === "myOutput") + + + // setOptimizer + lda.setOptimizer("em") + assert(lda.getOptimizer === "em") + lda.setOptimizer("online") + assert(lda.getOptimizer === "online") + lda.setLearningDecay(0.53) + assert(lda.getLearningDecay === 0.53) + lda.setLearningOffset(1027) + assert(lda.getLearningOffset === 1027) + lda.setSubsamplingRate(0.06) + assert(lda.getSubsamplingRate === 0.06) + lda.setOptimizeDocConcentration(false) + assert(!lda.getOptimizeDocConcentration) + } + + test("parameters validation") { + val lda = new LDA() + + // misc Params + intercept[IllegalArgumentException] { + new LDA().setK(1) + } + intercept[IllegalArgumentException] { + new LDA().setOptimizer("no_such_optimizer") + } + intercept[IllegalArgumentException] { + new LDA().setDocConcentration(-1.1) + } + intercept[IllegalArgumentException] { + new LDA().setTopicConcentration(-1.1) + } + + // validateParams() + lda.validateParams() + lda.setDocConcentration(1.1) + lda.validateParams() + lda.setDocConcentration(Range(0, lda.getK).map(_ + 2.0).toArray) + lda.validateParams() + lda.setDocConcentration(Range(0, lda.getK - 1).map(_ + 2.0).toArray) + withClue("LDA docConcentration validity check failed for bad array length") { + intercept[IllegalArgumentException] { + lda.validateParams() + } + } + + // Online LDA + intercept[IllegalArgumentException] { + new LDA().setLearningOffset(0) + } + intercept[IllegalArgumentException] { + new LDA().setLearningDecay(0) + } + intercept[IllegalArgumentException] { + new LDA().setSubsamplingRate(0) + } + intercept[IllegalArgumentException] { + new LDA().setSubsamplingRate(1.1) + } + } + + test("fit & transform with Online LDA") { + val lda = new LDA().setK(k).setSeed(1).setOptimizer("online").setMaxIter(2) + val model = lda.fit(dataset) + + MLTestingUtils.checkCopy(model) + + assert(model.isInstanceOf[LocalLDAModel]) + assert(model.vocabSize === vocabSize) + assert(model.estimatedDocConcentration.size === k) + assert(model.topicsMatrix.numRows === vocabSize) + assert(model.topicsMatrix.numCols === k) + assert(!model.isDistributed) + + // transform() + val transformed = model.transform(dataset) + val expectedColumns = Array("features", lda.getTopicDistributionCol) + expectedColumns.foreach { column => + assert(transformed.columns.contains(column)) + } + transformed.select(lda.getTopicDistributionCol).collect().foreach { r => + val topicDistribution = r.getAs[Vector](0) + assert(topicDistribution.size === k) + assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0)) + } + + // logLikelihood, logPerplexity + val ll = model.logLikelihood(dataset) + assert(ll <= 0.0 && ll != Double.NegativeInfinity) + val lp = model.logPerplexity(dataset) + assert(lp >= 0.0 && lp != Double.PositiveInfinity) + + // describeTopics + val topics = model.describeTopics(3) + assert(topics.count() === k) + assert(topics.select("topic").map(_.getInt(0)).collect().toSet === Range(0, k).toSet) + topics.select("termIndices").collect().foreach { case r: Row => + val termIndices = r.getAs[Seq[Int]](0) + assert(termIndices.length === 3 && termIndices.toSet.size === 3) + } + topics.select("termWeights").collect().foreach { case r: Row => + val termWeights = r.getAs[Seq[Double]](0) + assert(termWeights.length === 3 && termWeights.forall(w => w >= 0.0 && w <= 1.0)) + } + } + + test("fit & transform with EM LDA") { + val lda = new LDA().setK(k).setSeed(1).setOptimizer("em").setMaxIter(2) + val model_ = lda.fit(dataset) + + MLTestingUtils.checkCopy(model_) + + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + assert(model.vocabSize === vocabSize) + assert(model.estimatedDocConcentration.size === k) + assert(model.topicsMatrix.numRows === vocabSize) + assert(model.topicsMatrix.numCols === k) + assert(model.isDistributed) + + val localModel = model.toLocal + assert(localModel.isInstanceOf[LocalLDAModel]) + + // training logLikelihood, logPrior + val ll = model.trainingLogLikelihood + assert(ll <= 0.0 && ll != Double.NegativeInfinity) + val lp = model.logPrior + assert(lp <= 0.0 && lp != Double.NegativeInfinity) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index e5fd21c3f6fca..a02992a2407b3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -48,13 +48,13 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("rawText") .setOutputCol("tokens") val dataset0 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")), - TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct")) + TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")), + TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct")) )) testRegexTokenizer(tokenizer0, dataset0) val dataset1 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")), + TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")), TokenizerTestData("Te,st. punct", Array("punct")) )) tokenizer0.setMinTokenLength(3) @@ -64,11 +64,23 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("rawText") .setOutputCol("tokens") val dataset2 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")), - TokenizerTestData("Te,st. punct", Array("Te,st.", "punct")) + TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")), + TokenizerTestData("Te,st. punct", Array("te,st.", "punct")) )) testRegexTokenizer(tokenizer2, dataset2) } + + test("RegexTokenizer with toLowercase false") { + val tokenizer = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") + .setToLowercase(false) + val dataset = sqlContext.createDataFrame(Seq( + TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")), + TokenizerTestData("java scala", Array("java", "scala")) + )) + testRegexTokenizer(tokenizer, dataset) + } } object RegexTokenizerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 4545b0f281f5a..cac4bd9aa3ab8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -31,8 +31,9 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * Checks "overwrite" option and params. * @param instance ML instance to test saving/loading * @tparam T ML instance type + * @return Instance loaded from file */ - def testDefaultReadWrite[T <: Params with Writable](instance: T): Unit = { + def testDefaultReadWrite[T <: Params with Writable](instance: T): T = { val uid = instance.uid val path = new File(tempDir, uid).getPath @@ -61,6 +62,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => val load = instance.getClass.getMethod("load", classOf[String]) val another = load.invoke(instance, path).asInstanceOf[T] assert(another.uid === instance.uid) + another } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala new file mode 100644 index 0000000000000..41b9d5c0d93bb --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("default values") { + val bkm0 = new BisectingKMeans() + assert(bkm0.getK === 4) + assert(bkm0.getMaxIterations === 20) + assert(bkm0.getMinDivisibleClusterSize === 1.0) + val bkm1 = new BisectingKMeans() + assert(bkm0.getSeed === bkm1.getSeed, "The default seed should be constant.") + } + + test("setter/getter") { + val bkm = new BisectingKMeans() + + val k = 10 + assert(bkm.getK !== k) + assert(bkm.setK(k).getK === k) + val maxIter = 100 + assert(bkm.getMaxIterations !== maxIter) + assert(bkm.setMaxIterations(maxIter).getMaxIterations === maxIter) + val minSize = 2.0 + assert(bkm.getMinDivisibleClusterSize !== minSize) + assert(bkm.setMinDivisibleClusterSize(minSize).getMinDivisibleClusterSize === minSize) + val seed = 10L + assert(bkm.getSeed !== seed) + assert(bkm.setSeed(seed).getSeed === seed) + + intercept[IllegalArgumentException] { + bkm.setK(0) + } + intercept[IllegalArgumentException] { + bkm.setMaxIterations(0) + } + intercept[IllegalArgumentException] { + bkm.setMinDivisibleClusterSize(0.0) + } + } + + test("1D data") { + val points = Vectors.sparse(1, Array.empty, Array.empty) +: + (1 until 8).map(i => Vectors.dense(i)) + val data = sc.parallelize(points, 2) + val bkm = new BisectingKMeans() + .setK(4) + .setMaxIterations(1) + .setSeed(1L) + // The clusters should be + // (0, 1, 2, 3, 4, 5, 6, 7) + // - (0, 1, 2, 3) + // - (0, 1) + // - (2, 3) + // - (4, 5, 6, 7) + // - (4, 5) + // - (6, 7) + val model = bkm.run(data) + assert(model.k === 4) + // The total cost should be 8 * 0.5 * 0.5 = 2.0. + assert(model.computeCost(data) ~== 2.0 relTol 1e-12) + val predictions = data.map(v => (v(0), model.predict(v))).collectAsMap() + Range(0, 8, 2).foreach { i => + assert(predictions(i) === predictions(i + 1), + s"$i and ${i + 1} should belong to the same cluster.") + } + val root = model.root + assert(root.center(0) ~== 3.5 relTol 1e-12) + assert(root.height ~== 2.0 relTol 1e-12) + assert(root.children.length === 2) + assert(root.children(0).height ~== 1.0 relTol 1e-12) + assert(root.children(1).height ~== 1.0 relTol 1e-12) + } + + test("points are the same") { + val data = sc.parallelize(Seq.fill(8)(Vectors.dense(1.0, 1.0)), 2) + val bkm = new BisectingKMeans() + .setK(2) + .setMaxIterations(1) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 1) + } + + test("more desired clusters than points") { + val data = sc.parallelize(Seq.tabulate(4)(i => Vectors.dense(i)), 2) + val bkm = new BisectingKMeans() + .setK(8) + .setMaxIterations(2) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 4) + } + + test("min divisible cluster") { + val data = sc.parallelize( + Seq.tabulate(16)(i => Vectors.dense(i)) ++ Seq.tabulate(4)(i => Vectors.dense(-100.0 - i)), + 2) + val bkm = new BisectingKMeans() + .setK(4) + .setMinDivisibleClusterSize(10) + .setMaxIterations(1) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 3) + assert(model.predict(Vectors.dense(-100)) === model.predict(Vectors.dense(-97))) + assert(model.predict(Vectors.dense(7)) !== model.predict(Vectors.dense(8))) + + bkm.setMinDivisibleClusterSize(0.5) + val sameModel = bkm.run(data) + assert(sameModel.k === 3) + } + + test("larger clusters get selected first") { + val data = sc.parallelize( + Seq.tabulate(16)(i => Vectors.dense(i)) ++ Seq.tabulate(4)(i => Vectors.dense(-100.0 - i)), + 2) + val bkm = new BisectingKMeans() + .setK(3) + .setMaxIterations(1) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 3) + assert(model.predict(Vectors.dense(-100)) === model.predict(Vectors.dense(-97))) + assert(model.predict(Vectors.dense(7)) !== model.predict(Vectors.dense(8))) + } + + test("2D data") { + val points = Seq( + (11, 10), (9, 10), (10, 9), (10, 11), + (11, -10), (9, -10), (10, -9), (10, -11), + (0, 1), (0, -1) + ).map { case (x, y) => + if (x == 0) { + Vectors.sparse(2, Array(1), Array(y)) + } else { + Vectors.dense(x, y) + } + } + val data = sc.parallelize(points, 2) + val bkm = new BisectingKMeans() + .setK(3) + .setMaxIterations(4) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 3) + assert(model.root.center ~== Vectors.dense(8, 0) relTol 1e-12) + model.root.leafNodes.foreach { node => + if (node.center(0) < 5) { + assert(node.size === 2) + assert(node.center ~== Vectors.dense(0, 0) relTol 1e-12) + } else if (node.center(1) > 0) { + assert(node.size === 4) + assert(node.center ~== Vectors.dense(10, 10) relTol 1e-12) + } else { + assert(node.size === 4) + assert(node.center ~== Vectors.dense(10, -10) relTol 1e-12) + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index bc64172614830..ac93733bab5f5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -28,9 +28,12 @@ class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { for (numPartitions <- 1 to 8) { val rdd = sc.parallelize(data, numPartitions) for (windowSize <- 1 to 6) { - val sliding = rdd.sliding(windowSize).collect().map(_.toList).toList - val expected = data.sliding(windowSize).map(_.toList).toList - assert(sliding === expected) + for (step <- 1 to 3) { + val sliding = rdd.sliding(windowSize, step).collect().map(_.toList).toList + val expected = data.sliding(windowSize, step) + .map(_.toList).toList.filter(l => l.size == windowSize) + assert(sliding === expected) + } } assert(rdd.sliding(7).collect().isEmpty, "Should return an empty RDD if the window size is greater than the number of items.") @@ -40,7 +43,7 @@ class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("sliding with empty partitions") { val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7)) val rdd = sc.parallelize(data, data.length).flatMap(s => s) - assert(rdd.partitions.size === data.length) + assert(rdd.partitions.length === data.length) val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq) val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) assert(sliding === expected) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala index 49aff21fe7914..14152cdd63bc7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.tree import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} -import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. */ -class ImpuritySuite extends SparkFunSuite with MLlibTestSparkContext { +class ImpuritySuite extends SparkFunSuite { test("Gini impurity does not support negative labels") { val gini = new GiniAggregator(2) intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 5d1796ef65722..378139593b26f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -32,11 +32,14 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => .setMaster("local[2]") .setAppName("MLlibUnitTest") sc = new SparkContext(conf) + SQLContext.clearActive() sqlContext = new SQLContext(sc) + SQLContext.setActive(sqlContext) } override def afterAll() { sqlContext = null + SQLContext.clearActive() if (sc != null) { sc.stop() } diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index 43900e6f2c972..1b64b863a9fe5 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -59,15 +59,24 @@ public class TransportContext { private final TransportConf conf; private final RpcHandler rpcHandler; + private final boolean closeIdleConnections; private final MessageEncoder encoder; private final MessageDecoder decoder; public TransportContext(TransportConf conf, RpcHandler rpcHandler) { + this(conf, rpcHandler, false); + } + + public TransportContext( + TransportConf conf, + RpcHandler rpcHandler, + boolean closeIdleConnections) { this.conf = conf; this.rpcHandler = rpcHandler; this.encoder = new MessageEncoder(); this.decoder = new MessageDecoder(); + this.closeIdleConnections = closeIdleConnections; } /** @@ -144,7 +153,7 @@ private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, rpcHandler); return new TransportChannelHandler(client, responseHandler, requestHandler, - conf.connectionTimeoutMs()); + conf.connectionTimeoutMs(), closeIdleConnections); } public TransportConf getConf() { return conf; } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 4952ffb44bb8b..42a4f664e697c 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -158,6 +158,16 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO } } + /** + * Create a completely new {@link TransportClient} to the given remote host / port + * But this connection is not pooled. + */ + public TransportClient createUnmanagedClient(String remoteHost, int remotePort) + throws IOException { + final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); + return createClient(address); + } + /** Create a completely new {@link TransportClient} to the remote address. */ private TransportClient createClient(InetSocketAddress address) throws IOException { logger.debug("Creating new connection to " + address); diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 8e0ee709e38e3..f8fcd1c3d7d76 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -55,16 +55,19 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler 0; + // there's no race between the idle timeout and incrementing the numOutstandingRequests + // (see SPARK-7003). boolean isActuallyOverdue = System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; - if (e.state() == IdleState.ALL_IDLE && hasInFlightRequests && isActuallyOverdue) { - String address = NettyUtils.getRemoteAddress(ctx.channel()); - logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + - "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + - "is wrong.", address, requestTimeoutNs / 1000 / 1000); - ctx.close(); + if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { + if (responseHandler.numOutstandingRequests() > 0) { + String address = NettyUtils.getRemoteAddress(ctx.channel()); + logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + + "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + + "is wrong.", address, requestTimeoutNs / 1000 / 1000); + ctx.close(); + } else if (closeIdleConnections) { + // While CloseIdleConnections is enable, we also close idle connection + ctx.close(); + } } } } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 35de5e57ccb98..f447137419306 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; @@ -37,6 +38,7 @@ import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.ConfigProvider; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; @@ -177,4 +179,36 @@ public void closeBlockClientsWithFactory() throws IOException { assertFalse(c1.isActive()); assertFalse(c2.isActive()); } + + @Test + public void closeIdleConnectionForRequestTimeOut() throws IOException, InterruptedException { + TransportConf conf = new TransportConf(new ConfigProvider() { + + @Override + public String get(String name) { + if ("spark.shuffle.io.connectionTimeout".equals(name)) { + // We should make sure there is enough time for us to observe the channel is active + return "1s"; + } + String value = System.getProperty(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } + }); + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); + TransportClientFactory factory = context.createClientFactory(); + try { + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assertTrue(c1.isActive()); + long expiredTime = System.currentTimeMillis() + 10000; // 10 seconds + while (c1.isActive() && System.currentTimeMillis() < expiredTime) { + Thread.sleep(10); + } + assertFalse(c1.isActive()); + } finally { + factory.close(); + } + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index ea6d248d66be3..ef3a9dcc8711f 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -78,7 +78,7 @@ protected void checkInit() { @Override public void init(String appId) { this.appId = appId; - TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); List bootstraps = Lists.newArrayList(); if (saslEnabled) { bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled)); @@ -137,9 +137,13 @@ public void registerWithShuffleServer( String execId, ExecutorShuffleInfo executorInfo) throws IOException { checkInit(); - TransportClient client = clientFactory.createClient(host, port); - byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); - client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); + TransportClient client = clientFactory.createUnmanagedClient(host, port); + try { + byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); + client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); + } finally { + client.close(); + } } @Override diff --git a/pom.xml b/pom.xml index c965bcb0056a3..3853ce78cbb42 100644 --- a/pom.xml +++ b/pom.xml @@ -98,6 +98,7 @@ sql/catalyst sql/core sql/hive + docker-integration-tests unsafe assembly external/twitter @@ -154,6 +155,8 @@ 0.7.1 1.9.40 1.4.0 + + 0.10.1 4.3.2 @@ -390,6 +393,14 @@ + + + org.apache.xbean + xbean-asm5-shaded + 4.4 + diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 004941d5f50ae..3d2d235a00c93 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -23,15 +23,14 @@ import java.net.{HttpURLConnection, URI, URL, URLEncoder} import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.xbean.asm5._ +import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.{SparkConf, SparkEnv, Logging} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils import org.apache.spark.util.ParentClassLoader -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ - /** * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, * used to load classes defined by the interpreter when the REPL is used. @@ -192,7 +191,7 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader } class ConstructorCleaner(className: String, cv: ClassVisitor) -extends ClassVisitor(ASM4, cv) { +extends ClassVisitor(ASM5, cv) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { val mv = cv.visitMethod(access, name, desc, sig, exceptions) @@ -202,7 +201,7 @@ extends ClassVisitor(ASM4, cv) { // field in the class to point to it, but do nothing otherwise. mv.visitCode() mv.visitVarInsn(ALOAD, 0) // load this - mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "", "()V") + mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "", "()V", false) mv.visitVarInsn(ALOAD, 0) // load this // val classType = className.replace('.', '/') // mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";") diff --git a/sbin/start-master.sh b/sbin/start-master.sh index c20e19a8412df..9f2e14dff609f 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -23,6 +23,20 @@ if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi +# NOTE: This exact class name is matched downstream by SparkSubmit. +# Any changes need to be reflected there. +CLASS="org.apache.spark.deploy.master.Master" + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/start-master.sh [options]" + pattern="Usage:" + pattern+="\|Using Spark's default log4j profile:" + pattern+="\|Registered signal handlers for" + + "${SPARK_HOME}"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + exit 1 +fi + ORIGINAL_ARGS="$@" START_TACHYON=false @@ -30,7 +44,7 @@ START_TACHYON=false while (( "$#" )); do case $1 in --with-tachyon) - if [ ! -e "$sbin"/../tachyon/bin/tachyon ]; then + if [ ! -e "${SPARK_HOME}"/tachyon/bin/tachyon ]; then echo "Error: --with-tachyon specified, but tachyon not found." exit -1 fi @@ -56,12 +70,12 @@ if [ "$SPARK_MASTER_WEBUI_PORT" = "" ]; then SPARK_MASTER_WEBUI_PORT=8080 fi -"${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.master.Master 1 \ +"${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS 1 \ --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT \ $ORIGINAL_ARGS if [ "$START_TACHYON" == "true" ]; then - "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP - "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon format -s - "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon-start.sh master + "${SPARK_HOME}"/tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP + "${SPARK_HOME}"/tachyon/bin/tachyon format -s + "${SPARK_HOME}"/tachyon/bin/tachyon-start.sh master fi diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh index 21455648d1c6d..8c268b8859155 100755 --- a/sbin/start-slave.sh +++ b/sbin/start-slave.sh @@ -31,18 +31,24 @@ # worker. Subsequent workers will increment this # number. Default is 8081. -usage="Usage: start-slave.sh where is like spark://localhost:7077" - -if [ $# -lt 1 ]; then - echo $usage - echo Called as start-slave.sh $* - exit 1 -fi - if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi +# NOTE: This exact class name is matched downstream by SparkSubmit. +# Any changes need to be reflected there. +CLASS="org.apache.spark.deploy.worker.Worker" + +if [[ $# -lt 1 ]] || [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/start-slave.sh [options] " + pattern="Usage:" + pattern+="\|Using Spark's default log4j profile:" + pattern+="\|Registered signal handlers for" + + "${SPARK_HOME}"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + exit 1 +fi + . "${SPARK_HOME}/sbin/spark-config.sh" . "${SPARK_HOME}/bin/load-spark-env.sh" @@ -72,7 +78,7 @@ function start_instance { fi WEBUI_PORT=$(( $SPARK_WORKER_WEBUI_PORT + $WORKER_NUM - 1 )) - "${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker $WORKER_NUM \ + "${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS $WORKER_NUM \ --webui-port "$WEBUI_PORT" $PORT_FLAG $PORT_NUM $MASTER "$@" } diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 64a0c71bbef2a..050c3f360476f 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -150,6 +150,13 @@ This file is divided into 3 sections: // scalastyle:on println]]> + + @VisibleForTesting + + + Class\.forName sort(Iterator inputIterator) throws IOExce return sort(); } - /** - * Return true if UnsafeExternalRowSorter can sort rows with the given schema, false otherwise. - */ - public static boolean supportsSchema(StructType schema) { - return UnsafeProjection.canSupport(schema); - } - private static final class RowComparator extends RecordComparator { private final Ordering ordering; private final int numFields; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala new file mode 100644 index 0000000000000..c8b017e251637 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.reflect.ClassTag + +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} +import org.apache.spark.sql.types.StructType + +/** + * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. + * + * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking + * and reuse internal buffers to improve performance. + */ +trait Encoder[T] extends Serializable { + + /** Returns the schema of encoding this type of object as a Row. */ + def schema: StructType + + /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */ + def clsTag: ClassTag[T] +} + +object Encoders { + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) + def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) + def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) + def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) + + def tuple[T1, T2]( + e1: Encoder[T1], + e2: Encoder[T2]): Encoder[(T1, T2)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2)) + } + + def tuple[T1, T2, T3]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3]): Encoder[(T1, T2, T3)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3)) + } + + def tuple[T1, T2, T3, T4]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3], + e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4)) + } + + def tuple[T1, T2, T3, T4, T5]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3], + e4: Encoder[T4], + e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = { + ExpressionEncoder.tuple( + encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index ed2fdf9f2f7cf..0f0f200122c34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -152,7 +152,7 @@ trait Row extends Serializable { * BinaryType -> byte array * ArrayType -> scala.collection.Seq (use getList for java.util.List) * MapType -> scala.collection.Map (use getJavaMap for java.util.Map) - * StructType -> org.apache.spark.sql.Row + * StructType -> org.apache.spark.sql.Row (or Product) * }}} */ def apply(i: Int): Any = get(i) @@ -177,7 +177,7 @@ trait Row extends Serializable { * BinaryType -> byte array * ArrayType -> scala.collection.Seq (use getList for java.util.List) * MapType -> scala.collection.Map (use getJavaMap for java.util.Map) - * StructType -> org.apache.spark.sql.Row + * StructType -> org.apache.spark.sql.Row (or Product) * }}} */ def get(i: Int): Any @@ -306,7 +306,15 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getStruct(i: Int): Row = getAs[Row](i) + def getStruct(i: Int): Row = { + // Product and Row both are recoginized as StructType in a Row + val t = get(i) + if (t.isInstanceOf[Product]) { + Row.fromTuple(t.asInstanceOf[Product]) + } else { + t.asInstanceOf[Row] + } + } /** * Returns the value at position i. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index 3f351b07b37df..7c2b8a9407884 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst private[spark] trait CatalystConf { def caseSensitiveAnalysis: Boolean + + protected[spark] def specializeSingleDistinctAggPlanning: Boolean } /** @@ -29,7 +31,13 @@ object EmptyConf extends CatalystConf { override def caseSensitiveAnalysis: Boolean = { throw new UnsupportedOperationException } + + protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = { + throw new UnsupportedOperationException + } } /** A CatalystConf that can be used for local testing. */ -case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf +case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf { + protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = true +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 0b8a8abd02d67..0b3dd351e38e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -75,7 +75,7 @@ trait ScalaReflection { * * @see SPARK-5281 */ - private def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe + def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe /** * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping @@ -153,18 +153,18 @@ trait ScalaReflection { */ def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None) - protected def constructorFor( + private def constructorFor( tpe: `Type`, path: Option[Expression]): Expression = ScalaReflectionLock.synchronized { /** Returns the current path with a sub-field extracted. */ - def addToPath(part: String) = + def addToPath(part: String): Expression = path .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) .getOrElse(UnresolvedAttribute(part)) /** Returns the current path with a field at ordinal extracted. */ - def addToPathOrdinal(ordinal: Int, dataType: DataType) = + def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal)) .getOrElse(BoundReference(ordinal, dataType, false)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index cd717c09f8e5e..2a132d8b82bef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -22,6 +22,7 @@ import scala.language.implicitConversions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.DataTypeParser @@ -272,7 +273,7 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val function: Parser[Expression] = ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName => if (lexical.normalizeKeyword(udfName) == "count") { - Count(Literal(1)) + AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false) } else { throw new AnalysisException(s"invalid expression $udfName(*)") } @@ -281,14 +282,14 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) } | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => lexical.normalizeKeyword(udfName) match { - case "sum" => SumDistinct(exprs.head) - case "count" => CountDistinct(exprs) + case "count" => + aggregate.Count(exprs).toAggregateExpression(isDistinct = true) case _ => UnresolvedFunction(udfName, exprs, isDistinct = true) } } | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp => if (lexical.normalizeKeyword(udfName) == "count") { - ApproxCountDistinct(exp) + AggregateExpression(new HyperLogLogPlusPlus(exp), mode = Complete, isDistinct = false) } else { throw new AnalysisException(s"invalid function approximate $udfName") } @@ -296,7 +297,10 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { | APPROXIMATE ~> "(" ~> unsignedFloat ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ { case s ~ _ ~ udfName ~ _ ~ _ ~ exp => if (lexical.normalizeKeyword(udfName) == "count") { - ApproxCountDistinct(exp, s.toDouble) + AggregateExpression( + HyperLogLogPlusPlus(exp, s.toDouble, 0, 0), + mode = Complete, + isDistinct = false) } else { throw new AnalysisException(s"invalid function approximate($s) $udfName") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 899ee67352df4..2f4670b55bdba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef @@ -72,6 +72,7 @@ class Analyzer( ResolveRelations :: ResolveReferences :: ResolveGroupingAnalytics :: + ResolvePivot :: ResolveSortReferences :: ResolveGenerate :: ResolveFunctions :: @@ -79,6 +80,7 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: + DistinctAggregationRewriter(conf) :: HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, @@ -165,6 +167,10 @@ class Analyzer( case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) => g.withNewAggs(assignAliases(g.aggregations)) + case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) + if child.resolved && hasUnresolvedAlias(groupByExprs) => + Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child) + case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => Project(assignAliases(projectList), child) } @@ -247,6 +253,43 @@ class Analyzer( } } + object ResolvePivot extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case p: Pivot if !p.childrenResolved => p + case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => + val singleAgg = aggregates.size == 1 + val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => + def ifExpr(expr: Expression) = { + If(EqualTo(pivotColumn, value), expr, Literal(null)) + } + aggregates.map { aggregate => + val filteredAggregate = aggregate.transformDown { + // Assumption is the aggregate function ignores nulls. This is true for all current + // AggregateFunction's with the exception of First and Last in their default mode + // (which we handle) and possibly some Hive UDAF's. + case First(expr, _) => + First(ifExpr(expr), Literal(true)) + case Last(expr, _) => + Last(ifExpr(expr), Literal(true)) + case a: AggregateFunction => + a.withNewChildren(a.children.map(ifExpr)) + } + if (filteredAggregate.fastEquals(aggregate)) { + throw new AnalysisException( + s"Aggregate expression required for pivot, found '$aggregate'") + } + val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString + Alias(filteredAggregate, name)() + } + } + val newGroupByExprs = groupByExprs.map { + case UnresolvedAlias(e) => e + case e => e + } + Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child) + } + } + /** * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ @@ -525,21 +568,14 @@ class Analyzer( case u @ UnresolvedFunction(name, children, isDistinct) => withPosition(u) { registry.lookupFunction(name, children) match { - // We get an aggregate function built based on AggregateFunction2 interface. - // So, we wrap it in AggregateExpression2. - case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct) - // Currently, our old aggregate function interface supports SUM(DISTINCT ...) - // and COUTN(DISTINCT ...). - case sumDistinct: SumDistinct => sumDistinct - case countDistinct: CountDistinct => countDistinct - // DISTINCT is not meaningful with Max and Min. - case max: Max if isDistinct => max - case min: Min if isDistinct => min - // For other aggregate functions, DISTINCT keyword is not supported for now. - // Once we converted to the new code path, we will allow using DISTINCT keyword. - case other: AggregateExpression1 if isDistinct => - failAnalysis(s"$name does not support DISTINCT keyword.") - // If it does not have DISTINCT keyword, we will return it as is. + // DISTINCT is not meaningful for a Max or a Min. + case max: Max if isDistinct => + AggregateExpression(max, Complete, isDistinct = false) + case min: Min if isDistinct => + AggregateExpression(min, Complete, isDistinct = false) + // We get an aggregate function, we need to wrap it in an AggregateExpression. + case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct) + // This function is not an aggregate function, just return the resolved one. case other => other } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 98d6637c0601b..7b2c93d63d673 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, AggregateExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -108,7 +109,23 @@ trait CheckAnalysis { case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case _: AggregateExpression => // OK + case aggExpr: AggregateExpression => + aggExpr.aggregateFunction.children.foreach { child => + child.foreach { + case agg: AggregateExpression => + failAnalysis( + s"It is not allowed to use an aggregate function in the argument of " + + s"another aggregate function. Please use the inner aggregate function " + + s"in a sub-query.") + case other => // OK + } + + if (!child.deterministic) { + failAnalysis( + s"nondeterministic expression ${expr.prettyString} should not " + + s"appear in the arguments of an aggregate function.") + } + } case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( s"expression '${e.prettyString}' is neither present in the group by, " + @@ -120,14 +137,22 @@ trait CheckAnalysis { case e => e.children.foreach(checkValidAggregateExpression) } - def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match { - case BinaryType => - failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case m: MapType => - failAnalysis(s"map type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case _ => // OK + def checkValidGroupingExprs(expr: Expression): Unit = { + // Check if the data type of expr is orderable. + if (!RowOrdering.isOrderable(expr.dataType)) { + failAnalysis( + s"expression ${expr.prettyString} cannot be used as a grouping expression " + + s"because its data type ${expr.dataType.simpleString} is not a orderable " + + s"data type.") + } + + if (!expr.deterministic) { + // This is just a sanity check, our analysis rule PullOutNondeterministic should + // already pull out those nondeterministic expressions and evaluate them in + // a Project node. + failAnalysis(s"nondeterministic expression ${expr.prettyString} should not " + + s"appear in grouping expression.") + } } aggregateExprs.foreach(checkValidAggregateExpression) @@ -179,7 +204,8 @@ trait CheckAnalysis { s"unresolved operator ${operator.simpleString}") case o if o.expressions.exists(!_.deterministic) && - !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] => + !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] & !o.isInstanceOf[Aggregate] => + // The rule above is used to check Aggregate operator. failAnalysis( s"""nondeterministic expressions are only allowed in Project or Filter, found: | ${o.expressions.map(_.prettyString).mkString(",")} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala new file mode 100644 index 0000000000000..c0c960471a61a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.IntegerType + +/** + * This rule rewrites an aggregate query with distinct aggregations into an expanded double + * aggregation in which the regular aggregation expressions and every distinct clause is aggregated + * in a separate group. The results are then combined in a second aggregate. + * + * For example (in scala): + * {{{ + * val data = Seq( + * ("a", "ca1", "cb1", 10), + * ("a", "ca1", "cb2", 5), + * ("b", "ca1", "cb1", 13)) + * .toDF("key", "cat1", "cat2", "value") + * data.registerTempTable("data") + * + * val agg = data.groupBy($"key") + * .agg( + * countDistinct($"cat1").as("cat1_cnt"), + * countDistinct($"cat2").as("cat2_cnt"), + * sum($"value").as("total")) + * }}} + * + * This translates to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [COUNT(DISTINCT 'cat1), + * COUNT(DISTINCT 'cat2), + * sum('value)] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * LocalTableScan [...] + * }}} + * + * This rule rewrites this logical plan to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [count(if (('gid = 1)) 'cat1 else null), + * count(if (('gid = 2)) 'cat2 else null), + * first(if (('gid = 0)) 'total else null) ignore nulls] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * Aggregate( + * key = ['key, 'cat1, 'cat2, 'gid] + * functions = [sum('value)] + * output = ['key, 'cat1, 'cat2, 'gid, 'total]) + * Expand( + * projections = [('key, null, null, 0, cast('value as bigint)), + * ('key, 'cat1, null, 1, null), + * ('key, null, 'cat2, 2, null)] + * output = ['key, 'cat1, 'cat2, 'gid, 'value]) + * LocalTableScan [...] + * }}} + * + * The rule does the following things here: + * 1. Expand the data. There are three aggregation groups in this query: + * i. the non-distinct group; + * ii. the distinct 'cat1 group; + * iii. the distinct 'cat2 group. + * An expand operator is inserted to expand the child data for each group. The expand will null + * out all unused columns for the given group; this must be done in order to ensure correctness + * later on. Groups can by identified by a group id (gid) column added by the expand operator. + * 2. De-duplicate the distinct paths and aggregate the non-aggregate path. The group by clause of + * this aggregate consists of the original group by clause, all the requested distinct columns + * and the group id. Both de-duplication of distinct column and the aggregation of the + * non-distinct group take advantage of the fact that we group by the group id (gid) and that we + * have nulled out all non-relevant columns for the the given group. + * 3. Aggregating the distinct groups and combining this with the results of the non-distinct + * aggregation. In this step we use the group id to filter the inputs for the aggregate + * functions. The result of the non-distinct group are 'aggregated' by using the first operator, + * it might be more elegant to use the native UDAF merge mechanism for this in the future. + * + * This rule duplicates the input data by two or more times (# distinct groups + an optional + * non-distinct group). This will put quite a bit of memory pressure of the used aggregate and + * exchange operators. Keeping the number of distinct groups as low a possible should be priority, + * we could improve this in the current rule by applying more advanced expression cannocalization + * techniques. + */ +case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.resolved => p + // We need to wait until this Aggregate operator is resolved. + case a: Aggregate => rewrite(a) + case p => p + } + + def rewrite(a: Aggregate): Aggregate = { + + // Collect all aggregate expressions. + val aggExpressions = a.aggregateExpressions.flatMap { e => + e.collect { + case ae: AggregateExpression => ae + } + } + + // Extract distinct aggregate expressions. + val distinctAggGroups = aggExpressions + .filter(_.isDistinct) + .groupBy(_.aggregateFunction.children.toSet) + + val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) { + // When the flag is set to specialize single distinct agg planning, + // we will rely on our Aggregation strategy to handle queries with a single + // distinct column and this aggregate operator does have grouping expressions. + distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && a.groupingExpressions.isEmpty) + } else { + distinctAggGroups.size >= 1 + } + if (shouldRewrite) { + // Create the attributes for the grouping id and the group by clause. + val gid = new AttributeReference("gid", IntegerType, false)() + val groupByMap = a.groupingExpressions.collect { + case ne: NamedExpression => ne -> ne.toAttribute + case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)() + } + val groupByAttrs = groupByMap.map(_._2) + + // Functions used to modify aggregate functions and their inputs. + def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e)) + def patchAggregateFunctionChildren( + af: AggregateFunction)( + attrs: Expression => Expression): AggregateFunction = { + af.withNewChildren(af.children.map { + case afc => attrs(afc) + }).asInstanceOf[AggregateFunction] + } + + // Setup unique distinct aggregate children. + val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct + val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) + val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) + + // Setup expand & aggregate operators for distinct aggregate expressions. + val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap + val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { + case ((group, expressions), i) => + val id = Literal(i + 1) + + // Expand projection + val projection = distinctAggChildren.map { + case e if group.contains(e) => e + case e => nullify(e) + } :+ id + + // Final aggregate + val operators = expressions.map { e => + val af = e.aggregateFunction + val naf = patchAggregateFunctionChildren(af) { x => + evalWithinGroup(id, distinctAggChildAttrLookup(x)) + } + (e, e.copy(aggregateFunction = naf, isDistinct = false)) + } + + (projection, operators) + } + + // Setup expand for the 'regular' aggregate expressions. + val regularAggExprs = aggExpressions.filter(!_.isDistinct) + val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct + val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) + + // Setup aggregates for 'regular' aggregate expressions. + val regularGroupId = Literal(0) + val regularAggChildAttrLookup = regularAggChildAttrMap.toMap + val regularAggOperatorMap = regularAggExprs.map { e => + // Perform the actual aggregation in the initial aggregate. + val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup) + val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)() + + // Select the result of the first aggregate in the last aggregate. + val result = AggregateExpression( + aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)), + mode = Complete, + isDistinct = false) + + // Some aggregate functions (COUNT) have the special property that they can return a + // non-null result without any input. We need to make sure we return a result in this case. + val resultWithDefault = af.defaultResult match { + case Some(lit) => Coalesce(Seq(result, lit)) + case None => result + } + + // Return a Tuple3 containing: + // i. The original aggregate expression (used for look ups). + // ii. The actual aggregation operator (used in the first aggregate). + // iii. The operator that selects and returns the result (used in the second aggregate). + (e, operator, resultWithDefault) + } + + // Construct the regular aggregate input projection only if we need one. + val regularAggProjection = if (regularAggExprs.nonEmpty) { + Seq(a.groupingExpressions ++ + distinctAggChildren.map(nullify) ++ + Seq(regularGroupId) ++ + regularAggChildren) + } else { + Seq.empty[Seq[Expression]] + } + + // Construct the distinct aggregate input projections. + val regularAggNulls = regularAggChildren.map(nullify) + val distinctAggProjections = distinctAggOperatorMap.map { + case (projection, _) => + a.groupingExpressions ++ + projection ++ + regularAggNulls + } + + // Construct the expand operator. + val expand = Expand( + regularAggProjection ++ distinctAggProjections, + groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2), + a.child) + + // Construct the first aggregate operator. This de-duplicates the all the children of + // distinct operators, and applies the regular aggregate operators. + val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid + val firstAggregate = Aggregate( + firstAggregateGroupBy, + firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2), + expand) + + // Construct the second aggregate + val transformations: Map[Expression, Expression] = + (distinctAggOperatorMap.flatMap(_._2) ++ + regularAggOperatorMap.map(e => (e._1, e._3))).toMap + + val patchedAggExpressions = a.aggregateExpressions.map { e => + e.transformDown { + case e: Expression => + // The same GROUP BY clauses can have different forms (different names for instance) in + // the groupBy and aggregate expressions of an aggregate. This makes a map lookup + // tricky. So we do a linear search for a semantically equal group by expression. + groupByMap + .find(ge => e.semanticEquals(ge._1)) + .map(_._2) + .getOrElse(transformations.getOrElse(e, e)) + }.asInstanceOf[NamedExpression] + } + Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate) + } else { + a + } + } + + private def nullify(e: Expression) = Literal.create(null, e.dataType) + + private def expressionAttributePair(e: Expression) = + // We are creating a new reference here instead of reusing the attribute in case of a + // NamedExpression. This is done to prevent collisions between distinct and regular aggregate + // children, in this case attribute reuse causes the input of the regular aggregate to bound to + // the (nulled out) input of the distinct aggregate. + e -> new AttributeReference(e.prettyString, e.dataType, true)() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d4334d16289a5..a8f4d257acd0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -24,6 +24,7 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.StringKeyHashMap @@ -177,6 +178,7 @@ object FunctionRegistry { expression[ToRadians]("radians"), // aggregate functions + expression[HyperLogLogPlusPlus]("approx_count_distinct"), expression[Average]("avg"), expression[Corr]("corr"), expression[Count]("count"), @@ -260,6 +262,7 @@ object FunctionRegistry { expression[Quarter]("quarter"), expression[Second]("second"), expression[ToDate]("to_date"), + expression[ToUnixTimestamp]("to_unix_timestamp"), expression[ToUTCTimestamp]("to_utc_timestamp"), expression[TruncDate]("trunc"), expression[UnixTimestamp]("unix_timestamp"), @@ -278,7 +281,8 @@ object FunctionRegistry { expression[Sha1]("sha1"), expression[Sha2]("sha2"), expression[SparkPartitionID]("spark_partition_id"), - expression[InputFileName]("input_file_name") + expression[InputFileName]("input_file_name"), + expression[MonotonicallyIncreasingID]("monotonically_increasing_id") ) val builtin: SimpleFunctionRegistry = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 84e2b1366f626..92188ee54fd28 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -295,14 +296,19 @@ object HiveTypeCoercion { i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) - case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) - case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) - case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) - case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType)) - case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType)) - case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType)) - case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType)) + case StddevPop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + StddevPop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case StddevSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + StddevSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + VarianceSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case Skewness(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + Skewness(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case Kurtosis(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + Kurtosis(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) } } @@ -562,12 +568,6 @@ object HiveTypeCoercion { case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType)) case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType)) - case s @ SumDistinct(e @ DecimalType()) => s // Decimal is already the biggest. - case SumDistinct(e @ IntegralType()) if e.dataType != LongType => - SumDistinct(Cast(e, LongType)) - case SumDistinct(e @ FractionalType()) if e.dataType != DoubleType => - SumDistinct(Cast(e, DoubleType)) - case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest. case Average(e @ IntegralType()) if e.dataType != LongType => Average(Cast(e, LongType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index eae17c86ddc7a..6485bdfb30234 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -141,6 +141,10 @@ case class UnresolvedFunction( override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false + override def prettyString: String = { + s"${name}(${children.map(_.prettyString).mkString(",")})" + } + override def toString: String = s"'$name(${children.mkString(",")})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index d8df66430a695..af594c25c54cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -23,6 +23,7 @@ import scala.language.implicitConversions import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.types._ @@ -144,17 +145,18 @@ package object dsl { } } - def sum(e: Expression): Expression = Sum(e) - def sumDistinct(e: Expression): Expression = SumDistinct(e) - def count(e: Expression): Expression = Count(e) - def countDistinct(e: Expression*): Expression = CountDistinct(e) + def sum(e: Expression): Expression = Sum(e).toAggregateExpression() + def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true) + def count(e: Expression): Expression = Count(e).toAggregateExpression() + def countDistinct(e: Expression*): Expression = + Count(e).toAggregateExpression(isDistinct = true) def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression = - ApproxCountDistinct(e, rsd) - def avg(e: Expression): Expression = Average(e) - def first(e: Expression): Expression = First(e) - def last(e: Expression): Expression = Last(e) - def min(e: Expression): Expression = Min(e) - def max(e: Expression): Expression = Max(e) + HyperLogLogPlusPlus(e, rsd).toAggregateExpression() + def avg(e: Expression): Expression = Average(e).toAggregateExpression() + def first(e: Expression): Expression = new First(e).toAggregateExpression() + def last(e: Expression): Expression = new Last(e).toAggregateExpression() + def min(e: Expression): Expression = Min(e).toAggregateExpression() + def max(e: Expression): Expression = Max(e).toAggregateExpression() def upper(e: Expression): Expression = Upper(e) def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala deleted file mode 100644 index f05e18288de2b..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.reflect.ClassTag - -import org.apache.spark.util.Utils -import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType} -import org.apache.spark.sql.catalyst.expressions._ - -/** - * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. - * - * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking - * and reuse internal buffers to improve performance. - */ -trait Encoder[T] extends Serializable { - - /** Returns the schema of encoding this type of object as a Row. */ - def schema: StructType - - /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */ - def clsTag: ClassTag[T] -} - -object Encoder { - import scala.reflect.runtime.universe._ - - def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) - def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) - def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) - def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) - def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) - def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) - def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) - def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) - - def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = { - tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]])) - .asInstanceOf[ExpressionEncoder[(T1, T2)]] - } - - def tuple[T1, T2, T3]( - enc1: Encoder[T1], - enc2: Encoder[T2], - enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = { - tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]])) - .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]] - } - - def tuple[T1, T2, T3, T4]( - enc1: Encoder[T1], - enc2: Encoder[T2], - enc3: Encoder[T3], - enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = { - tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]])) - .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]] - } - - def tuple[T1, T2, T3, T4, T5]( - enc1: Encoder[T1], - enc2: Encoder[T2], - enc3: Encoder[T3], - enc4: Encoder[T4], - enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = { - tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]])) - .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]] - } - - private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { - assert(encoders.length > 1) - // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`. - assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty)) - - val schema = StructType(encoders.zipWithIndex.map { - case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) - }) - - val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") - - val extractExpressions = encoders.map { - case e if e.flat => e.extractExpressions.head - case other => CreateStruct(other.extractExpressions) - }.zipWithIndex.map { case (expr, index) => - expr.transformUp { - case BoundReference(0, t: ObjectType, _) => - Invoke( - BoundReference(0, ObjectType(cls), true), - s"_${index + 1}", - t) - } - } - - val constructExpressions = encoders.zipWithIndex.map { case (enc, index) => - if (enc.flat) { - enc.constructExpression.transform { - case b: BoundReference => b.copy(ordinal = index) - } - } else { - enc.constructExpression.transformUp { - case BoundReference(ordinal, dt, _) => - GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt) - } - } - } - - val constructExpression = - NewInstance(cls, constructExpressions, false, ObjectType(cls)) - - new ExpressionEncoder[Any]( - schema, - false, - extractExpressions, - constructExpression, - ClassTag.apply(cls)) - } - - - def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)] - - private def getTypeTag[T](c: Class[T]): TypeTag[T] = { - import scala.reflect.api - - // val mirror = runtimeMirror(c.getClassLoader) - val mirror = rootMirror - val sym = mirror.staticClass(c.getName) - val tpe = sym.selfType - TypeTag(mirror, new api.TypeCreator { - def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) = - if (m eq mirror) tpe.asInstanceOf[U # Type] - else throw new IllegalArgumentException( - s"Type tag defined in $mirror cannot be migrated to other mirrors.") - }) - } - - def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = { - implicit val typeTag1 = getTypeTag(c1) - implicit val typeTag2 = getTypeTag(c2) - ExpressionEncoder[(T1, T2)]() - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index c287aebeeee05..9a1a8f5cbbdc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -17,18 +17,18 @@ package org.apache.spark.sql.catalyst.encoders -import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} -import org.apache.spark.util.Utils - import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} +import org.apache.spark.util.Utils +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType} +import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType} /** * A factory for constructing encoders that convert objects and primitves to and from the @@ -61,69 +61,123 @@ object ExpressionEncoder { /** * Given a set of N encoders, constructs a new encoder that produce objects as items in an - * N-tuple. Note that these encoders should first be bound correctly to the combined input - * schema. + * N-tuple. Note that these encoders should be unresolved so that information about + * name/positional binding is preserved. */ def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { - val schema = - StructType( - encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)}) + encoders.foreach(_.assertUnresolved()) + + val schema = StructType(encoders.zipWithIndex.map { + case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) + }) + val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") - val extractExpressions = encoders.map { - case e if e.flat => e.extractExpressions.head - case other => CreateStruct(other.extractExpressions) + + val toRowExpressions = encoders.map { + case e if e.flat => e.toRowExpressions.head + case other => CreateStruct(other.toRowExpressions) + }.zipWithIndex.map { case (expr, index) => + expr.transformUp { + case BoundReference(0, t, _) => + Invoke( + BoundReference(0, ObjectType(cls), nullable = true), + s"_${index + 1}", + t) + } } - val constructExpression = - NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls)) + + val fromRowExpressions = encoders.zipWithIndex.map { case (enc, index) => + if (enc.flat) { + enc.fromRowExpression.transform { + case b: BoundReference => b.copy(ordinal = index) + } + } else { + val input = BoundReference(index, enc.schema, nullable = true) + enc.fromRowExpression.transformUp { + case UnresolvedAttribute(nameParts) => + assert(nameParts.length == 1) + UnresolvedExtractValue(input, Literal(nameParts.head)) + case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt) + } + } + } + + val fromRowExpression = + NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls)) new ExpressionEncoder[Any]( schema, - false, - extractExpressions, - constructExpression, - ClassTag.apply(cls)) + flat = false, + toRowExpressions, + fromRowExpression, + ClassTag(cls)) } - /** A helper for producing encoders of Tuple2 from other encoders. */ def tuple[T1, T2]( e1: ExpressionEncoder[T1], e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] = - tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]] + tuple(Seq(e1, e2)).asInstanceOf[ExpressionEncoder[(T1, T2)]] + + def tuple[T1, T2, T3]( + e1: ExpressionEncoder[T1], + e2: ExpressionEncoder[T2], + e3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] = + tuple(Seq(e1, e2, e3)).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]] + + def tuple[T1, T2, T3, T4]( + e1: ExpressionEncoder[T1], + e2: ExpressionEncoder[T2], + e3: ExpressionEncoder[T3], + e4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] = + tuple(Seq(e1, e2, e3, e4)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]] + + def tuple[T1, T2, T3, T4, T5]( + e1: ExpressionEncoder[T1], + e2: ExpressionEncoder[T2], + e3: ExpressionEncoder[T3], + e4: ExpressionEncoder[T4], + e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] = + tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]] } /** * A generic encoder for JVM objects. * * @param schema The schema after converting `T` to a Spark SQL row. - * @param extractExpressions A set of expressions, one for each top-level field that can be used to - * extract the values from a raw object. + * @param toRowExpressions A set of expressions, one for each top-level field that can be used to + * extract the values from a raw object into an [[InternalRow]]. + * @param fromRowExpression An expression that will construct an object given an [[InternalRow]]. * @param clsTag A classtag for `T`. */ case class ExpressionEncoder[T]( schema: StructType, flat: Boolean, - extractExpressions: Seq[Expression], - constructExpression: Expression, + toRowExpressions: Seq[Expression], + fromRowExpression: Expression, clsTag: ClassTag[T]) extends Encoder[T] { - if (flat) require(extractExpressions.size == 1) + if (flat) require(toRowExpressions.size == 1) @transient - private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) + private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions) private val inputRow = new GenericMutableRow(1) @transient - private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil) + private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil) /** * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should * copy the result before making another call if required. */ - def toRow(t: T): InternalRow = { + def toRow(t: T): InternalRow = try { inputRow(0) = t extractProjection(inputRow) + } catch { + case e: Exception => + throw new RuntimeException( + s"Error while encoding: $e\n${toRowExpressions.map(_.treeString).mkString("\n")}", e) } /** @@ -135,7 +189,20 @@ case class ExpressionEncoder[T]( constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T] } catch { case e: Exception => - throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e) + throw new RuntimeException(s"Error while decoding: $e\n${fromRowExpression.treeString}", e) + } + + /** + * The process of resolution to a given schema throws away information about where a given field + * is being bound by ordinal instead of by name. This method checks to make sure this process + * has not been done already in places where we plan to do later composition of encoders. + */ + def assertUnresolved(): Unit = { + (fromRowExpression +: toRowExpressions).foreach(_.foreach { + case a: AttributeReference => + sys.error(s"Unresolved encoder expected, but $a was found.") + case _ => + }) } /** @@ -143,9 +210,14 @@ case class ExpressionEncoder[T]( * given schema. */ def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = { - val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema)) + val positionToAttribute = AttributeMap.toIndex(schema) + val unbound = fromRowExpression transform { + case b: BoundReference => positionToAttribute(b.ordinal) + } + + val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) - copy(constructExpression = analyzedPlan.expressions.head.children.head) + copy(fromRowExpression = analyzedPlan.expressions.head.children.head) } /** @@ -154,55 +226,19 @@ case class ExpressionEncoder[T]( * resolve before bind. */ def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = { - copy(constructExpression = BindReferences.bindReference(constructExpression, schema)) - } - - /** - * Replaces any bound references in the schema with the attributes at the corresponding ordinal - * in the provided schema. This can be used to "relocate" a given encoder to pull values from - * a different schema than it was initially bound to. It can also be used to assign attributes - * to ordinal based extraction (i.e. because the input data was a tuple). - */ - def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(schema) - copy(constructExpression = constructExpression transform { - case b: BoundReference => positionToAttribute(b.ordinal) - }) - } - - /** - * Given an encoder that has already been bound to a given schema, returns a new encoder - * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example, - * when you are trying to use an encoder on grouping keys that were originally part of a larger - * row, but now you have projected out only the key expressions. - */ - def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(oldSchema) - val attributeToNewPosition = AttributeMap.byIndex(newSchema) - copy(constructExpression = constructExpression transform { - case r: BoundReference => - r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal))) - }) + copy(fromRowExpression = BindReferences.bindReference(fromRowExpression, schema)) } /** - * Returns a copy of this encoder where the expressions used to create an object given an - * input row have been modified to pull the object out from a nested struct, instead of the - * top level fields. + * Returns a new encoder with input columns shifted by `delta` ordinals */ - def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = { - copy(constructExpression = constructExpression transform { - case u: Attribute if u != input => - UnresolvedExtractValue(input, Literal(u.name)) - case b: BoundReference if b != input => - GetStructField( - input, - StructField(s"i[${b.ordinal}]", b.dataType), - b.ordinal) + def shift(delta: Int): ExpressionEncoder[T] = { + copy(fromRowExpression = fromRowExpression transform { + case r: BoundReference => r.copy(ordinal = r.ordinal + delta) }) } - protected val attrs = extractExpressions.flatMap(_.collect { + protected val attrs = toRowExpressions.flatMap(_.collect { case _: UnresolvedAttribute => "" case a: Attribute => s"#${a.exprId}" case b: BoundReference => s"[${b.ordinal}]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala new file mode 100644 index 0000000000000..6d307ab13a9fc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{typeTag, TypeTag} + +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.catalyst.expressions.{Literal, CreateNamedStruct, BoundReference} +import org.apache.spark.sql.catalyst.ScalaReflection + +object FlatEncoder { + import ScalaReflection.schemaFor + import ScalaReflection.dataTypeFor + + def apply[T : TypeTag]: ExpressionEncoder[T] = { + // We convert the not-serializable TypeTag into StructType and ClassTag. + val tpe = typeTag[T].tpe + val mirror = typeTag[T].mirror + val cls = mirror.runtimeClass(tpe) + assert(!schemaFor(tpe).dataType.isInstanceOf[StructType]) + + val input = BoundReference(0, dataTypeFor(tpe), nullable = true) + val toRowExpression = CreateNamedStruct( + Literal("value") :: ProductEncoder.extractorFor(input, tpe) :: Nil) + val fromRowExpression = ProductEncoder.constructorFor(tpe) + + new ExpressionEncoder[T]( + toRowExpression.dataType, + flat = true, + toRowExpression.flatten, + fromRowExpression, + ClassTag[T](cls)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala new file mode 100644 index 0000000000000..414adb21168ed --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.util.Utils +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, ArrayBasedMapData, GenericArrayData} + +import scala.reflect.ClassTag + +object ProductEncoder { + import ScalaReflection.universe._ + import ScalaReflection.localTypeOf + import ScalaReflection.dataTypeFor + import ScalaReflection.Schema + import ScalaReflection.schemaFor + import ScalaReflection.arrayClassFor + + def apply[T <: Product : TypeTag]: ExpressionEncoder[T] = { + // We convert the not-serializable TypeTag into StructType and ClassTag. + val tpe = typeTag[T].tpe + val mirror = typeTag[T].mirror + val cls = mirror.runtimeClass(tpe) + + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val toRowExpression = extractorFor(inputObject, tpe).asInstanceOf[CreateNamedStruct] + val fromRowExpression = constructorFor(tpe) + + new ExpressionEncoder[T]( + toRowExpression.dataType, + flat = false, + toRowExpression.flatten, + fromRowExpression, + ClassTag[T](cls)) + } + + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + def extractorFor( + inputObject: Expression, + tpe: `Type`): Expression = ScalaReflectionLock.synchronized { + if (!inputObject.dataType.isInstanceOf[ObjectType]) { + inputObject + } else { + tpe match { + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + optType match { + // For primitive types we must manually unbox the value of the object. + case t if t <:< definitions.IntTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject), + "intValue", + IntegerType) + case t if t <:< definitions.LongTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject), + "longValue", + LongType) + case t if t <:< definitions.DoubleTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject), + "doubleValue", + DoubleType) + case t if t <:< definitions.FloatTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject), + "floatValue", + FloatType) + case t if t <:< definitions.ShortTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject), + "shortValue", + ShortType) + case t if t <:< definitions.ByteTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject), + "byteValue", + ByteType) + case t if t <:< definitions.BooleanTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject), + "booleanValue", + BooleanType) + + // For non-primitives, we can just extract the object from the Option and then recurse. + case other => + val className: String = optType.erasure.typeSymbol.asClass.fullName + val classObj = Utils.classForName(className) + val optionObjectType = ObjectType(classObj) + + val unwrapped = UnwrapOption(optionObjectType, inputObject) + expressions.If( + IsNull(unwrapped), + expressions.Literal.create(null, schemaFor(optType).dataType), + extractorFor(unwrapped, optType)) + } + + case t if t <:< localTypeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val constructorSymbol = t.member(nme.CONSTRUCTOR) + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramss + } else { + // Find the primary constructor, and use its parameter ordering. + val primaryConstructorSymbol: Option[Symbol] = + constructorSymbol.asTerm.alternatives.find(s => + s.isMethod && s.asMethod.isPrimaryConstructor) + + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramss + } + } + + CreateNamedStruct(params.head.flatMap { p => + val fieldName = p.name.toString + val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + }) + + case t if t <:< localTypeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + toCatalystArray(inputObject, elementType) + + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + toCatalystArray(inputObject, elementType) + + case t if t <:< localTypeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + + val keys = + Invoke( + Invoke(inputObject, "keysIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedKeys = toCatalystArray(keys, keyType) + + val values = + Invoke( + Invoke(inputObject, "valuesIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedValues = toCatalystArray(values, valueType) + + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + NewInstance( + classOf[ArrayBasedMapData], + convertedKeys :: convertedValues :: Nil, + dataType = MapType(keyDataType, valueDataType, valueNullable)) + + case t if t <:< localTypeOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + DateType, + "fromJavaDate", + inputObject :: Nil) + + case t if t <:< localTypeOf[BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.math.BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case t if t <:< localTypeOf[java.lang.Long] => + Invoke(inputObject, "longValue", LongType) + case t if t <:< localTypeOf[java.lang.Double] => + Invoke(inputObject, "doubleValue", DoubleType) + case t if t <:< localTypeOf[java.lang.Float] => + Invoke(inputObject, "floatValue", FloatType) + case t if t <:< localTypeOf[java.lang.Short] => + Invoke(inputObject, "shortValue", ShortType) + case t if t <:< localTypeOf[java.lang.Byte] => + Invoke(inputObject, "byteValue", ByteType) + case t if t <:< localTypeOf[java.lang.Boolean] => + Invoke(inputObject, "booleanValue", BooleanType) + + case other => + throw new UnsupportedOperationException(s"Extractor for type $other is not supported") + } + } + } + + private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { + val externalDataType = dataTypeFor(elementType) + val Schema(catalystType, nullable) = schemaFor(elementType) + if (RowEncoder.isNativeType(catalystType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(catalystType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), input, externalDataType) + } + } + + def constructorFor( + tpe: `Type`, + path: Option[Expression] = None): Expression = ScalaReflectionLock.synchronized { + + /** Returns the current path with a sub-field extracted. */ + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + + /** Returns the current path with a field at ordinal extracted. */ + def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path + .map(p => GetInternalRowField(p, ordinal, dataType)) + .getOrElse(BoundReference(ordinal, dataType, false)) + + /** Returns the current path or `BoundReference`. */ + def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) + + tpe match { + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath + + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + WrapOption(null, constructorFor(optType, path)) + + case t if t <:< localTypeOf[java.lang.Integer] => + val boxedType = classOf[java.lang.Integer] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Long] => + val boxedType = classOf[java.lang.Long] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Double] => + val boxedType = classOf[java.lang.Double] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Float] => + val boxedType = classOf[java.lang.Float] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Short] => + val boxedType = classOf[java.lang.Short] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Byte] => + val boxedType = classOf[java.lang.Byte] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Boolean] => + val boxedType = classOf[java.lang.Boolean] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Date]), + "toJavaDate", + getPath :: Nil, + propagateNull = true) + + case t if t <:< localTypeOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Timestamp]), + "toJavaTimestamp", + getPath :: Nil, + propagateNull = true) + + case t if t <:< localTypeOf[java.lang.String] => + Invoke(getPath, "toString", ObjectType(classOf[String])) + + case t if t <:< localTypeOf[java.math.BigDecimal] => + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + + case t if t <:< localTypeOf[BigDecimal] => + Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + + case t if t <:< localTypeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val primitiveMethod = elementType match { + case t if t <:< definitions.IntTpe => Some("toIntArray") + case t if t <:< definitions.LongTpe => Some("toLongArray") + case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") + case t if t <:< definitions.FloatTpe => Some("toFloatArray") + case t if t <:< definitions.ShortTpe => Some("toShortArray") + case t if t <:< definitions.ByteTpe => Some("toByteArray") + case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") + case _ => None + } + + primitiveMethod.map { method => + Invoke(getPath, method, arrayClassFor(elementType)) + }.getOrElse { + Invoke( + MapObjects( + p => constructorFor(elementType, Some(p)), + getPath, + schemaFor(elementType).dataType), + "array", + arrayClassFor(elementType)) + } + + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val arrayData = + Invoke( + MapObjects( + p => constructorFor(elementType, Some(p)), + getPath, + schemaFor(elementType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + scala.collection.mutable.WrappedArray, + ObjectType(classOf[Seq[_]]), + "make", + arrayData :: Nil) + + case t if t <:< localTypeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + + val keyData = + Invoke( + MapObjects( + p => constructorFor(keyType, Some(p)), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + schemaFor(keyType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + val valueData = + Invoke( + MapObjects( + p => constructorFor(valueType, Some(p)), + Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), + schemaFor(valueType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + ArrayBasedMapData, + ObjectType(classOf[Map[_, _]]), + "toScalaMap", + keyData :: valueData :: Nil) + + case t if t <:< localTypeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val constructorSymbol = t.member(nme.CONSTRUCTOR) + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramss + } else { + // Find the primary constructor, and use its parameter ordering. + val primaryConstructorSymbol: Option[Symbol] = + constructorSymbol.asTerm.alternatives.find(s => + s.isMethod && s.asMethod.isPrimaryConstructor) + + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramss + } + } + + val className: String = t.erasure.typeSymbol.asClass.fullName + val cls = Utils.classForName(className) + + val arguments = params.head.zipWithIndex.map { case (p, i) => + val fieldName = p.name.toString + val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val dataType = schemaFor(fieldType).dataType + + // For tuples, we based grab the inner fields by ordinal instead of name. + if (className startsWith "scala.Tuple") { + constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) + } else { + constructorFor(fieldType, Some(addToPath(fieldName))) + } + } + + val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) + + if (path.nonEmpty) { + expressions.If( + IsNull(getPath), + expressions.Literal.create(null, ObjectType(cls)), + newInstance + ) + } else { + newInstance + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 0b42130a013b2..9bb1602494b68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -50,6 +50,14 @@ object RowEncoder { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => inputObject + case udt: UserDefinedType[_] => + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + false, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) + case TimestampType => StaticInvoke( DateTimeUtils, @@ -109,19 +117,32 @@ object RowEncoder { case StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => + val method = if (f.dataType.isInstanceOf[StructType]) { + "getStruct" + } else { + "get" + } If( Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, f.dataType), extractorsFor( - Invoke(inputObject, "get", externalDataTypeFor(f.dataType), Literal(i) :: Nil), + Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil), f.dataType)) } CreateStruct(convertedFields) } - private def externalDataTypeFor(dt: DataType): DataType = dt match { + /** + * Returns true if the value of this data type is same between internal and external. + */ + def isNativeType(dt: DataType): Boolean = dt match { case BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => dt + FloatType | DoubleType | BinaryType => true + case _ => false + } + + private def externalDataTypeFor(dt: DataType): DataType = dt match { + case _ if isNativeType(dt) => dt case TimestampType => ObjectType(classOf[java.sql.Timestamp]) case DateType => ObjectType(classOf[java.sql.Date]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) @@ -129,6 +150,7 @@ object RowEncoder { case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) case _: StructType => ObjectType(classOf[Row]) + case udt: UserDefinedType[_] => ObjectType(udt.userClass) } private def constructorFor(schema: StructType): Expression = { @@ -137,16 +159,24 @@ object RowEncoder { If( IsNull(field), Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(BoundReference(i, f.dataType, f.nullable), f.dataType) + constructorFor(BoundReference(i, f.dataType, f.nullable)) ) } CreateExternalRow(fields) } - private def constructorFor(input: Expression, dataType: DataType): Expression = dataType match { + private def constructorFor(input: Expression): Expression = input.dataType match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => input + case udt: UserDefinedType[_] => + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + false, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil) + case TimestampType => StaticInvoke( DateTimeUtils, @@ -170,7 +200,7 @@ object RowEncoder { case ArrayType(et, nullable) => val arrayData = Invoke( - MapObjects(constructorFor(_, et), input, et), + MapObjects(constructorFor, input, et), "array", ObjectType(classOf[Array[_]])) StaticInvoke( @@ -181,10 +211,10 @@ object RowEncoder { case MapType(kt, vt, valueNullable) => val keyArrayType = ArrayType(kt, false) - val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType), keyArrayType) + val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType)) val valueArrayType = ArrayType(vt, valueNullable) - val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType), valueArrayType) + val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType)) StaticInvoke( ArrayBasedMapData, @@ -197,42 +227,8 @@ object RowEncoder { If( Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(getField(input, i, f.dataType), f.dataType)) + constructorFor(GetInternalRowField(input, i, f.dataType))) } CreateExternalRow(convertedFields) } - - private def getField( - row: Expression, - ordinal: Int, - dataType: DataType): Expression = dataType match { - case BooleanType => - Invoke(row, "getBoolean", dataType, Literal(ordinal) :: Nil) - case ByteType => - Invoke(row, "getByte", dataType, Literal(ordinal) :: Nil) - case ShortType => - Invoke(row, "getShort", dataType, Literal(ordinal) :: Nil) - case IntegerType | DateType => - Invoke(row, "getInt", dataType, Literal(ordinal) :: Nil) - case LongType | TimestampType => - Invoke(row, "getLong", dataType, Literal(ordinal) :: Nil) - case FloatType => - Invoke(row, "getFloat", dataType, Literal(ordinal) :: Nil) - case DoubleType => - Invoke(row, "getDouble", dataType, Literal(ordinal) :: Nil) - case t: DecimalType => - Invoke(row, "getDecimal", dataType, Seq(ordinal, t.precision, t.scale).map(Literal(_))) - case StringType => - Invoke(row, "getUTF8String", dataType, Literal(ordinal) :: Nil) - case BinaryType => - Invoke(row, "getBinary", dataType, Literal(ordinal) :: Nil) - case CalendarIntervalType => - Invoke(row, "getInterval", dataType, Literal(ordinal) :: Nil) - case t: StructType => - Invoke(row, "getStruct", dataType, Literal(ordinal) :: Literal(t.size) :: Nil) - case _: ArrayType => - Invoke(row, "getArray", dataType, Literal(ordinal) :: Nil) - case _: MapType => - Invoke(row, "getMap", dataType, Literal(ordinal) :: Nil) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala index d4642a500672e..9e283f5eb6342 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala @@ -17,10 +17,20 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.expressions.AttributeReference + package object encoders { + /** + * Returns an internal encoder object that can be used to serialize / deserialize JVM objects + * into Spark SQL rows. The implicit encoder should always be unresolved (i.e. have no attribute + * references from a specific schema.) This requirement allows us to preserve whether a given + * object type is being bound by name or by ordinal when doing resolution. + */ private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match { - case e: ExpressionEncoder[A] => e + case e: ExpressionEncoder[A] => + e.assertUnresolved() + e case _ => sys.error(s"Only expression encoders are supported today") } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala new file mode 100644 index 0000000000000..f7162e420d19a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.mutable + +/** + * This class is used to compute equality of (sub)expression trees. Expressions can be added + * to this class and they subsequently query for expression equality. Expression trees are + * considered equal if for the same input(s), the same result is produced. + */ +class EquivalentExpressions { + /** + * Wrapper around an Expression that provides semantic equality. + */ + case class Expr(e: Expression) { + override def equals(o: Any): Boolean = o match { + case other: Expr => e.semanticEquals(other.e) + case _ => false + } + override val hashCode: Int = e.semanticHash() + } + + // For each expression, the set of equivalent expressions. + private val equivalenceMap = mutable.HashMap.empty[Expr, mutable.MutableList[Expression]] + + /** + * Adds each expression to this data structure, grouping them with existing equivalent + * expressions. Non-recursive. + * Returns true if there was already a matching expression. + */ + def addExpr(expr: Expression): Boolean = { + if (expr.deterministic) { + val e: Expr = Expr(expr) + val f = equivalenceMap.get(e) + if (f.isDefined) { + f.get += expr + true + } else { + equivalenceMap.put(e, mutable.MutableList(expr)) + false + } + } else { + false + } + } + + /** + * Adds the expression to this data structure recursively. Stops if a matching expression + * is found. That is, if `expr` has already been added, its children are not added. + * If ignoreLeaf is true, leaf nodes are ignored. + */ + def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = { + val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf + if (!skip && !addExpr(root)) { + root.children.foreach(addExprTree(_, ignoreLeaf)) + } + } + + /** + * Returns all of the expression trees that are equivalent to `e`. Returns + * an empty collection if there are none. + */ + def getEquivalentExprs(e: Expression): Seq[Expression] = { + equivalenceMap.getOrElse(Expr(e), mutable.MutableList()) + } + + /** + * Returns all the equivalent sets of expressions. + */ + def getAllEquivalentExprs: Seq[Seq[Expression]] = { + equivalenceMap.values.map(_.toSeq).toSeq + } + + /** + * Returns the state of the data structure as a string. If `all` is false, skips sets of + * equivalent expressions with cardinality 1. + */ + def debugString(all: Boolean = false): String = { + val sb: mutable.StringBuilder = new StringBuilder() + sb.append("Equivalent expressions:\n") + equivalenceMap.foreach { case (k, v) => { + if (all || v.length > 1) { + sb.append(" " + v.mkString(", ")).append("\n") + } + }} + sb.toString() + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 96fcc799e537a..540ed3500616a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -92,12 +92,23 @@ abstract class Expression extends TreeNode[Expression] { * @return [[GeneratedExpressionCode]] */ def gen(ctx: CodeGenContext): GeneratedExpressionCode = { - val isNull = ctx.freshName("isNull") - val primitive = ctx.freshName("primitive") - val ve = GeneratedExpressionCode("", isNull, primitive) - ve.code = genCode(ctx, ve) - // Add `this` in the comment. - ve.copy(s"/* $this */\n" + ve.code) + ctx.subExprEliminationExprs.get(this).map { subExprState => + // This expression is repeated meaning the code to evaluated has already been added + // as a function, `subExprState.fnName`. Just call that. + val code = + s""" + |/* $this */ + |${subExprState.fnName}(${ctx.INPUT_ROW}); + """.stripMargin.trim + GeneratedExpressionCode(code, subExprState.code.isNull, subExprState.code.value) + }.getOrElse { + val isNull = ctx.freshName("isNull") + val primitive = ctx.freshName("primitive") + val ve = GeneratedExpressionCode("", isNull, primitive) + ve.code = genCode(ctx, ve) + // Add `this` in the comment. + ve.copy(s"/* $this */\n" + ve.code.trim) + } } /** @@ -145,11 +156,37 @@ abstract class Expression extends TreeNode[Expression] { case (i1, i2) => i1 == i2 } } + // Non-deterministic expressions cannot be semantic equal + if (!deterministic || !other.deterministic) return false val elements1 = this.productIterator.toSeq val elements2 = other.asInstanceOf[Product].productIterator.toSeq checkSemantic(elements1, elements2) } + /** + * Returns the hash for this expression. Expressions that compute the same result, even if + * they differ cosmetically should return the same hash. + */ + def semanticHash() : Int = { + def computeHash(e: Seq[Any]): Int = { + // See http://stackoverflow.com/questions/113511/hash-code-implementation + var hash: Int = 17 + e.foreach(i => { + val h: Int = i match { + case e: Expression => e.semanticHash() + case Some(e: Expression) => e.semanticHash() + case t: Traversable[_] => computeHash(t.toSeq) + case null => 0 + case other => other.hashCode() + } + hash = hash * 37 + h + }) + hash + } + + computeHash(this.productIterator.toSeq) + } + /** * Checks the input data types, returns `TypeCheckResult.success` if it's valid, * or returns a `TypeCheckResult` with an error message if invalid. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 79dabe8e925ad..053e612f3ecb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -102,16 +102,6 @@ abstract class UnsafeProjection extends Projection { object UnsafeProjection { - /* - * Returns whether UnsafeProjection can support given StructType, Array[DataType] or - * Seq[Expression]. - */ - def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType)) - def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray) - private def canSupport(types: Array[DataType]): Boolean = { - types.forall(GenerateUnsafeProjection.canSupport) - } - /** * Returns an UnsafeProjection for given StructType. */ @@ -144,6 +134,22 @@ object UnsafeProjection { def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { create(exprs.map(BindReferences.bindReference(_, inputSchema))) } + + /** + * Same as other create()'s but allowing enabling/disabling subexpression elimination. + * TODO: refactor the plumbing and clean this up. + */ + def create( + exprs: Seq[Expression], + inputSchema: Seq[Attribute], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + val e = exprs.map(BindReferences.bindReference(_, inputSchema)) + .map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index c8c20ada5fbc7..94ac4bf09b90b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ case class Average(child: Expression) extends DeclarativeAggregate { @@ -32,36 +34,33 @@ case class Average(child: Expression) extends DeclarativeAggregate { // Return data type. override def dataType: DataType = resultType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select avg(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) - private val resultType = child.dataType match { + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") + + private lazy val resultType = child.dataType match { case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 4, s + 4) case _ => DoubleType } - private val sumDataType = child.dataType match { + private lazy val sumDataType = child.dataType match { case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) case _ => DoubleType } - private val sum = AttributeReference("sum", sumDataType)() - private val count = AttributeReference("count", LongType)() + private lazy val sum = AttributeReference("sum", sumDataType)() + private lazy val count = AttributeReference("count", LongType)() - override val aggBufferAttributes = sum :: count :: Nil + override lazy val aggBufferAttributes = sum :: count :: Nil - override val initialValues = Seq( + override lazy val initialValues = Seq( /* sum = */ Cast(Literal(0), sumDataType), /* count = */ Literal(0L) ) - override val updateExpressions = Seq( + override lazy val updateExpressions = Seq( /* sum = */ Add( sum, @@ -69,13 +68,13 @@ case class Average(child: Expression) extends DeclarativeAggregate { /* count = */ If(IsNull(child), count, count + 1L) ) - override val mergeExpressions = Seq( + override lazy val mergeExpressions = Seq( /* sum = */ sum.left + sum.right, /* count = */ count.left + count.right ) // If all input are nulls, count will be 0 and we will get null after the division. - override val evaluateExpression = child.dataType match { + override lazy val evaluateExpression = child.dataType match { case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index ef08b025ff556..de5872ab11eb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -55,13 +57,10 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w override def dataType: DataType = DoubleType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select avg(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName") override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 832338378fb38..00d7436b710d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -35,6 +37,9 @@ case class Corr( inputAggBufferOffset: Int = 0) extends ImperativeAggregate { + def this(left: Expression, right: Expression) = + this(left, right, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def children: Seq[Expression] = Seq(left, right) override def nullable: Boolean = false @@ -43,6 +48,16 @@ case class Corr( override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + override def checkInputDataTypes(): TypeCheckResult = { + if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"corr requires that both arguments are double type, " + + s"not (${left.dataType}, ${right.dataType}).") + } + } + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) override def inputAggBufferAttributes: Seq[AttributeReference] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index ec0c8b483a909..09a1da9200df0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -32,23 +32,39 @@ case class Count(child: Expression) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val count = AttributeReference("count", LongType)() + private lazy val count = AttributeReference("count", LongType)() - override val aggBufferAttributes = count :: Nil + override lazy val aggBufferAttributes = count :: Nil - override val initialValues = Seq( + override lazy val initialValues = Seq( /* count = */ Literal(0L) ) - override val updateExpressions = Seq( + override lazy val updateExpressions = Seq( /* count = */ If(IsNull(child), count, count + 1L) ) - override val mergeExpressions = Seq( + override lazy val mergeExpressions = Seq( /* count = */ count.left + count.right ) - override val evaluateExpression = Cast(count, LongType) + override lazy val evaluateExpression = Cast(count, LongType) override def defaultResult: Option[Literal] = Option(Literal(0L)) } + +object Count { + def apply(children: Seq[Expression]): Count = { + // This is used to deal with COUNT DISTINCT. When we have multiple + // children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row). + // Also, the semantic of COUNT(DISTINCT col1, col2, ...) is that if there is any + // null in the arguments, we will not count that row. So, we use DropAnyNull at here + // to return a null when any field of the created STRUCT is null. + val child = if (children.size > 1) { + DropAnyNull(CreateStruct(children)) + } else { + children.head + } + Count(child) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 9028143015853..35f57426feaf2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -51,18 +51,18 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val first = AttributeReference("first", child.dataType)() + private lazy val first = AttributeReference("first", child.dataType)() - private val valueSet = AttributeReference("valueSet", BooleanType)() + private lazy val valueSet = AttributeReference("valueSet", BooleanType)() - override val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil + override lazy val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil - override val initialValues: Seq[Literal] = Seq( + override lazy val initialValues: Seq[Literal] = Seq( /* first = */ Literal.create(null, child.dataType), /* valueSet = */ Literal.create(false, BooleanType) ) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( /* first = */ If(Or(valueSet, IsNull(child)), first, child), @@ -76,7 +76,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara } } - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { // For first, we can just check if valueSet.left is set to true. If it is set // to true, we use first.right. If not, we use first.right (even if valueSet.right is // false, we are safe to do so because first.right will be null in this case). @@ -86,7 +86,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara ) } - override val evaluateExpression: AttributeReference = first + override lazy val evaluateExpression: AttributeReference = first override def toString: String = s"first($child)${if (ignoreNulls) " ignore nulls"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 8d341ee630bdb..8a95c541f1e86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -22,6 +22,7 @@ import java.util import com.clearspring.analytics.hash.MurmurHash +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -55,6 +56,22 @@ case class HyperLogLogPlusPlus( extends ImperativeAggregate { import HyperLogLogPlusPlus._ + def this(child: Expression) = { + this(child = child, relativeSD = 0.05, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + } + + def this(child: Expression, relativeSD: Expression) = { + this( + child = child, + relativeSD = relativeSD match { + case Literal(d: Double, DoubleType) => d + case _ => + throw new AnalysisException("The second argument should be a double literal.") + }, + mutableAggBufferOffset = 0, + inputAggBufferOffset = 0) + } + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala index 6da39e7143447..8fa3aac9f1a51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala @@ -24,6 +24,8 @@ case class Kurtosis(child: Expression, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -40,9 +42,11 @@ case class Kurtosis(child: Expression, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") val m2 = moments(2) val m4 = moments(4) + if (n == 0.0 || m2 == 0.0) { Double.NaN - } else { + } + else { n * m4 / (m2 * m2) - 3.0 } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 8636bfe8d07aa..be7e12d7a2336 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -51,15 +51,15 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val last = AttributeReference("last", child.dataType)() + private lazy val last = AttributeReference("last", child.dataType)() - override val aggBufferAttributes: Seq[AttributeReference] = last :: Nil + override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: Nil - override val initialValues: Seq[Literal] = Seq( + override lazy val initialValues: Seq[Literal] = Seq( /* last = */ Literal.create(null, child.dataType) ) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( /* last = */ If(IsNull(child), last, child) @@ -71,7 +71,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat } } - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( /* last = */ If(IsNull(last.right), last.left, last.right) @@ -83,7 +83,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat } } - override val evaluateExpression: AttributeReference = last + override lazy val evaluateExpression: AttributeReference = last override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index b9d75ad452838..61cae44cd0f5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ case class Max(child: Expression) extends DeclarativeAggregate { @@ -32,24 +34,27 @@ case class Max(child: Expression) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val max = AttributeReference("max", child.dataType)() + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function max") - override val aggBufferAttributes: Seq[AttributeReference] = max :: Nil + private lazy val max = AttributeReference("max", child.dataType)() - override val initialValues: Seq[Literal] = Seq( + override lazy val aggBufferAttributes: Seq[AttributeReference] = max :: Nil + + override lazy val initialValues: Seq[Literal] = Seq( /* max = */ Literal.create(null, child.dataType) ) - override val updateExpressions: Seq[Expression] = Seq( + override lazy val updateExpressions: Seq[Expression] = Seq( /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child)))) ) - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { val greatest = Greatest(Seq(max.left, max.right)) Seq( /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest)) ) } - override val evaluateExpression: AttributeReference = max + override lazy val evaluateExpression: AttributeReference = max } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 5ed9cd348daba..242456d9e2e18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -33,24 +35,27 @@ case class Min(child: Expression) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val min = AttributeReference("min", child.dataType)() + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function min") - override val aggBufferAttributes: Seq[AttributeReference] = min :: Nil + private lazy val min = AttributeReference("min", child.dataType)() - override val initialValues: Seq[Expression] = Seq( + override lazy val aggBufferAttributes: Seq[AttributeReference] = min :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( /* min = */ Literal.create(null, child.dataType) ) - override val updateExpressions: Seq[Expression] = Seq( + override lazy val updateExpressions: Seq[Expression] = Seq( /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child)))) ) - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { val least = Least(Seq(min.left, min.right)) Seq( /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least)) ) } - override val evaluateExpression: AttributeReference = min + override lazy val evaluateExpression: AttributeReference = min } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala index 0def7ddfd9d3d..e1c01a5b82781 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala @@ -24,6 +24,8 @@ case class Skewness(child: Expression, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -39,9 +41,11 @@ case class Skewness(child: Expression, s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}") val m2 = moments(2) val m3 = moments(3) + if (n == 0.0 || m2 == 0.0) { Double.NaN - } else { + } + else { math.sqrt(n) * m3 / math.sqrt(m2 * m2 * m2) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala index 3f47ffe13cbc8..05dd5e3b22543 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala @@ -17,118 +17,55 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ +case class StddevSamp(child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends CentralMomentAgg(child) { -// Compute the population standard deviation of a column -case class StddevPop(child: Expression) extends StddevAgg(child) { - override def isSample: Boolean = false - override def prettyName: String = "stddev_pop" -} - - -// Compute the sample standard deviation of a column -case class StddevSamp(child: Expression) extends StddevAgg(child) { - override def isSample: Boolean = true - override def prettyName: String = "stddev_samp" -} - - -// Compute standard deviation based on online algorithm specified here: -// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance -abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - def isSample: Boolean + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) - override def children: Seq[Expression] = child :: Nil + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) - override def nullable: Boolean = true - - override def dataType: DataType = resultType + override def prettyName: String = "stddev_samp" - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select stddev(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + override protected val momentOrder = 2 - private val resultType = DoubleType + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - private val count = AttributeReference("count", resultType)() - private val avg = AttributeReference("avg", resultType)() - private val mk = AttributeReference("mk", resultType)() + if (n == 0.0 || n == 1.0) Double.NaN else math.sqrt(moments(2) / (n - 1.0)) + } +} - override val aggBufferAttributes = count :: avg :: mk :: Nil +case class StddevPop( + child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends CentralMomentAgg(child) { - override val initialValues: Seq[Expression] = Seq( - /* count = */ Cast(Literal(0), resultType), - /* avg = */ Cast(Literal(0), resultType), - /* mk = */ Cast(Literal(0), resultType) - ) + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - override val updateExpressions: Seq[Expression] = { - val value = Cast(child, resultType) - val newCount = count + Cast(Literal(1), resultType) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) - // update average - // avg = avg + (value - avg)/count - val newAvg = avg + (value - avg) / newCount + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) - // update sum ofference from mean - // Mk = Mk + (value - preAvg) * (value - updatedAvg) - val newMk = mk + (value - avg) * (value - newAvg) + override def prettyName: String = "stddev_pop" - Seq( - /* count = */ If(IsNull(child), count, newCount), - /* avg = */ If(IsNull(child), avg, newAvg), - /* mk = */ If(IsNull(child), mk, newMk) - ) - } + override protected val momentOrder = 2 - override val mergeExpressions: Seq[Expression] = { - - // count merge - val newCount = count.left + count.right - - // average merge - val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount - - // update sum of square differences - val newMk = { - val avgDelta = avg.right - avg.left - val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount - mk.left + mk.right + mkDelta - } - - Seq( - /* count = */ If(IsNull(count.left), count.right, - If(IsNull(count.right), count.left, newCount)), - /* avg = */ If(IsNull(avg.left), avg.right, - If(IsNull(avg.right), avg.left, newAvg)), - /* mk = */ If(IsNull(mk.left), mk.right, - If(IsNull(mk.right), mk.left, newMk)) - ) - } + override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = { + require(moments.length == momentOrder + 1, + s"$prettyName requires ${momentOrder + 1} central moment, received: ${moments.length}") - override val evaluateExpression: Expression = { - // when count == 0, return null - // when count == 1, return 0 - // when count >1 - // stddev_samp = sqrt (mk/(count -1)) - // stddev_pop = sqrt (mk/count) - val varCol = - if (isSample) { - mk / Cast(count - Cast(Literal(1), resultType), resultType) - } else { - mk / count - } - - If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), - If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), - Cast(Sqrt(varCol), resultType))) + if (n == 0.0) Double.NaN else math.sqrt(moments(2) / n) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 7f8adbc56ad1d..cfb042e0aa782 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ case class Sum(child: Expression) extends DeclarativeAggregate { @@ -29,16 +31,13 @@ case class Sum(child: Expression) extends DeclarativeAggregate { // Return data type. override def dataType: DataType = resultType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select sum(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) + Seq(TypeCollection(LongType, DoubleType, DecimalType)) - private val resultType = child.dataType match { + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sum") + + private lazy val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) // TODO: Remove this line once we remove the NullType from inputTypes. @@ -46,24 +45,24 @@ case class Sum(child: Expression) extends DeclarativeAggregate { case _ => child.dataType } - private val sumDataType = resultType + private lazy val sumDataType = resultType - private val sum = AttributeReference("sum", sumDataType)() + private lazy val sum = AttributeReference("sum", sumDataType)() - private val zero = Cast(Literal(0), sumDataType) + private lazy val zero = Cast(Literal(0), sumDataType) - override val aggBufferAttributes = sum :: Nil + override lazy val aggBufferAttributes = sum :: Nil - override val initialValues: Seq[Expression] = Seq( + override lazy val initialValues: Seq[Expression] = Seq( /* sum = */ Literal.create(null, sumDataType) ) - override val updateExpressions: Seq[Expression] = Seq( + override lazy val updateExpressions: Seq[Expression] = Seq( /* sum = */ Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) ) - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType)) Seq( /* sum = */ @@ -71,5 +70,5 @@ case class Sum(child: Expression) extends DeclarativeAggregate { ) } - override val evaluateExpression: Expression = Cast(sum, resultType) + override lazy val evaluateExpression: Expression = Cast(sum, resultType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala deleted file mode 100644 index 39010c3be6d4e..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala +++ /dev/null @@ -1,393 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.aggregate - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.types.{IntegerType, StructType, MapType, ArrayType} - -/** - * Utility functions used by the query planner to convert our plan to new aggregation code path. - */ -object Utils { - // Right now, we do not support complex types in the grouping key schema. - private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = { - val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists { - case array: ArrayType => true - case map: MapType => true - case struct: StructType => true - case _ => false - } - - !hasComplexTypes - } - - private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match { - case p: Aggregate if supportsGroupingKeySchema(p) => - val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown { - case expressions.Average(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Average(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Count(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(child), - mode = aggregate.Complete, - isDistinct = false) - - // We do not support multiple COUNT DISTINCT columns for now. - case expressions.CountDistinct(children) if children.length == 1 => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(children.head), - mode = aggregate.Complete, - isDistinct = true) - - case expressions.First(child, ignoreNulls) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.First(child, ignoreNulls), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Kurtosis(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Kurtosis(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Last(child, ignoreNulls) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Last(child, ignoreNulls), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Max(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Max(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Min(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Min(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Skewness(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Skewness(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.StddevPop(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.StddevPop(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.StddevSamp(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.StddevSamp(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Sum(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.SumDistinct(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = true) - - case expressions.Corr(left, right) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Corr(left, right), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.ApproxCountDistinct(child, rsd) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.HyperLogLogPlusPlus(child, rsd), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.VariancePop(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.VariancePop(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.VarianceSamp(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.VarianceSamp(child), - mode = aggregate.Complete, - isDistinct = false) - }) - - // Check if there is any expressions.AggregateExpression1 left. - // If so, we cannot convert this plan. - val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => - // For every expressions, check if it contains AggregateExpression1. - expr.find { - case agg: expressions.AggregateExpression1 => true - case other => false - }.isDefined - } - - // Check if there are multiple distinct columns. - // TODO remove this. - val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg - } - }.toSet.toSeq - val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) - val hasMultipleDistinctColumnSets = - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - true - } else { - false - } - - if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None - - case other => None - } - - def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = { - // If the plan cannot be converted, we will do a final round check to see if the original - // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so, - // we need to throw an exception. - val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg.aggregateFunction - } - }.distinct - if (aggregateFunction2s.nonEmpty) { - // For functions implemented based on the new interface, prepare a list of function names. - val invalidFunctions = { - if (aggregateFunction2s.length > 1) { - s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " + - s"and ${aggregateFunction2s.head.nodeName} are" - } else { - s"${aggregateFunction2s.head.nodeName} is" - } - } - val errorMessage = - s"${invalidFunctions} implemented based on the new Aggregate Function " + - s"interface and it cannot be used with functions implemented based on " + - s"the old Aggregate Function interface." - throw new AnalysisException(errorMessage) - } - } - - def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match { - case p: Aggregate => - val converted = doConvert(p) - if (converted.isDefined) { - converted - } else { - checkInvalidAggregateFunction2(p) - None - } - case other => None - } -} - -/** - * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double - * aggregation in which the regular aggregation expressions and every distinct clause is aggregated - * in a separate group. The results are then combined in a second aggregate. - * - * TODO Expression cannocalization - * TODO Eliminate foldable expressions from distinct clauses. - * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate - * operator. Perhaps this is a good thing? It is much simpler to plan later on... - */ -object MultipleDistinctRewriter extends Rule[LogicalPlan] { - - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case a: Aggregate => rewrite(a) - case p => p - } - - def rewrite(a: Aggregate): Aggregate = { - - // Collect all aggregate expressions. - val aggExpressions = a.aggregateExpressions.flatMap { e => - e.collect { - case ae: AggregateExpression2 => ae - } - } - - // Extract distinct aggregate expressions. - val distinctAggGroups = aggExpressions - .filter(_.isDistinct) - .groupBy(_.aggregateFunction.children.toSet) - - // Only continue to rewrite if there is more than one distinct group. - if (distinctAggGroups.size > 1) { - // Create the attributes for the grouping id and the group by clause. - val gid = new AttributeReference("gid", IntegerType, false)() - val groupByMap = a.groupingExpressions.collect { - case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> new AttributeReference(e.prettyName, e.dataType, e.nullable)() - } - val groupByAttrs = groupByMap.map(_._2) - - // Functions used to modify aggregate functions and their inputs. - def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e)) - def patchAggregateFunctionChildren( - af: AggregateFunction2, - id: Literal, - attrs: Map[Expression, Expression]): AggregateFunction2 = { - af.withNewChildren(af.children.map { case afc => - evalWithinGroup(id, attrs(afc)) - }).asInstanceOf[AggregateFunction2] - } - - // Setup unique distinct aggregate children. - val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq - val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap - val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq - - // Setup expand & aggregate operators for distinct aggregate expressions. - val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { - case ((group, expressions), i) => - val id = Literal(i + 1) - - // Expand projection - val projection = distinctAggChildren.map { - case e if group.contains(e) => e - case e => nullify(e) - } :+ id - - // Final aggregate - val operators = expressions.map { e => - val af = e.aggregateFunction - val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap) - (e, e.copy(aggregateFunction = naf, isDistinct = false)) - } - - (projection, operators) - } - - // Setup expand for the 'regular' aggregate expressions. - val regularAggExprs = aggExpressions.filter(!_.isDistinct) - val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct - val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap - - // Setup aggregates for 'regular' aggregate expressions. - val regularGroupId = Literal(0) - val regularAggOperatorMap = regularAggExprs.map { e => - // Perform the actual aggregation in the initial aggregate. - val af = patchAggregateFunctionChildren( - e.aggregateFunction, - regularGroupId, - regularAggChildAttrMap) - val a = Alias(e.copy(aggregateFunction = af), e.toString)() - - // Get the result of the first aggregate in the last aggregate. - val b = AggregateExpression2( - aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)), - mode = Complete, - isDistinct = false) - - // Some aggregate functions (COUNT) have the special property that they can return a - // non-null result without any input. We need to make sure we return a result in this case. - val c = af.defaultResult match { - case Some(lit) => Coalesce(Seq(b, lit)) - case None => b - } - - (e, a, c) - } - - // Construct the regular aggregate input projection only if we need one. - val regularAggProjection = if (regularAggExprs.nonEmpty) { - Seq(a.groupingExpressions ++ - distinctAggChildren.map(nullify) ++ - Seq(regularGroupId) ++ - regularAggChildren) - } else { - Seq.empty[Seq[Expression]] - } - - // Construct the distinct aggregate input projections. - val regularAggNulls = regularAggChildren.map(nullify) - val distinctAggProjections = distinctAggOperatorMap.map { - case (projection, _) => - a.groupingExpressions ++ - projection ++ - regularAggNulls - } - - // Construct the expand operator. - val expand = Expand( - regularAggProjection ++ distinctAggProjections, - groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.values.toSeq, - a.child) - - // Construct the first aggregate operator. This de-duplicates the all the children of - // distinct operators, and applies the regular aggregate operators. - val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid - val firstAggregate = Aggregate( - firstAggregateGroupBy, - firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2), - expand) - - // Construct the second aggregate - val transformations: Map[Expression, Expression] = - (distinctAggOperatorMap.flatMap(_._2) ++ - regularAggOperatorMap.map(e => (e._1, e._3))).toMap - - val patchedAggExpressions = a.aggregateExpressions.map { e => - e.transformDown { - case e: Expression => - // The same GROUP BY clauses can have different forms (different names for instance) in - // the groupBy and aggregate expressions of an aggregate. This makes a map lookup - // tricky. So we do a linear search for a semantically equal group by expression. - groupByMap - .find(ge => e.semanticEquals(ge._1)) - .map(_._2) - .getOrElse(transformations.getOrElse(e, e)) - }.asInstanceOf[NamedExpression] - } - Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate) - } else { - a - } - } - - private def nullify(e: Expression) = Literal.create(null, e.dataType) - - private def expressionAttributePair(e: Expression) = - // We are creating a new reference here instead of reusing the attribute in case of a - // NamedExpression. This is done to prevent collisions between distinct and regular aggregate - // children, in this case attribute reuse causes the input of the regular aggregate to bound to - // the (nulled out) input of the distinct aggregate. - e -> new AttributeReference(e.prettyName, e.dataType, true)() -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala index ec63534e5290a..ede2da2805966 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala @@ -24,6 +24,8 @@ case class VarianceSamp(child: Expression, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -42,11 +44,14 @@ case class VarianceSamp(child: Expression, } } -case class VariancePop(child: Expression, +case class VariancePop( + child: Expression, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 5c5b3d1ccd3cd..3b441de34a49f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -17,23 +17,24 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -/** The mode of an [[AggregateFunction2]]. */ +/** The mode of an [[AggregateFunction]]. */ private[sql] sealed trait AggregateMode /** - * An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation. + * An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the aggregation buffer is returned. */ private[sql] case object Partial extends AggregateMode /** - * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers + * An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers * containing intermediate results for this function. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the aggregation buffer is returned. @@ -41,7 +42,7 @@ private[sql] case object Partial extends AggregateMode private[sql] case object PartialMerge extends AggregateMode /** - * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers + * An [[AggregateFunction]] with [[Final]] mode is used to merge aggregation buffers * containing intermediate results for this function and then generate final result. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the final result of this function is returned. @@ -49,7 +50,7 @@ private[sql] case object PartialMerge extends AggregateMode private[sql] case object Final extends AggregateMode /** - * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly + * An [[AggregateFunction]] with [[Complete]] mode is used to evaluate this function directly * from original input rows without any partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the final result of this function is returned. @@ -67,13 +68,15 @@ private[sql] case object NoOp extends Expression with Unevaluable { } /** - * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field + * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. */ -private[sql] case class AggregateExpression2( - aggregateFunction: AggregateFunction2, +private[sql] case class AggregateExpression( + aggregateFunction: AggregateFunction, mode: AggregateMode, - isDistinct: Boolean) extends AggregateExpression { + isDistinct: Boolean) + extends Expression + with Unevaluable { override def children: Seq[Expression] = aggregateFunction :: Nil override def dataType: DataType = aggregateFunction.dataType @@ -89,6 +92,8 @@ private[sql] case class AggregateExpression2( AttributeSet(childReferences) } + override def prettyString: String = aggregateFunction.prettyString + override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)" } @@ -106,10 +111,10 @@ private[sql] case class AggregateExpression2( * combined aggregation buffer which concatenates the aggregation buffers of the individual * aggregate functions. * - * Code which accepts [[AggregateFunction2]] instances should be prepared to handle both types of + * Code which accepts [[AggregateFunction]] instances should be prepared to handle both types of * aggregate functions. */ -sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInputTypes { +sealed abstract class AggregateFunction extends Expression with ImplicitCastInputTypes { /** An aggregate function is not foldable. */ final override def foldable: Boolean = false @@ -141,6 +146,27 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + + /** + * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] because + * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, + * and the flag indicating if this aggregation is distinct aggregation or not. + * An [[AggregateFunction]] should not be used without being wrapped in + * an [[AggregateExpression]]. + */ + def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false) + + /** + * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and set isDistinct + * field of the [[AggregateExpression]] to the given value because + * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, + * and the flag indicating if this aggregation is distinct aggregation or not. + * An [[AggregateFunction]] should not be used without being wrapped in + * an [[AggregateExpression]]. + */ + def toAggregateExpression(isDistinct: Boolean): AggregateExpression = { + AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct) + } } /** @@ -161,7 +187,7 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp * `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes` * and `inputAggBufferAttributes`. */ -abstract class ImperativeAggregate extends AggregateFunction2 { +abstract class ImperativeAggregate extends AggregateFunction { /** * The offset of this function's first buffer value in the underlying shared mutable aggregation @@ -258,9 +284,14 @@ abstract class ImperativeAggregate extends AggregateFunction2 { * `bufferAttributes`, defining attributes for the fields of the mutable aggregation buffer. You * can then use these attributes when defining `updateExpressions`, `mergeExpressions`, and * `evaluateExpressions`. + * + * Please note that children of an aggregate function can be unresolved (it will happen when + * we create this function in DataFrame API). So, if there is any fields in + * the implemented class that need to access fields of its children, please make + * those fields `lazy val`s. */ abstract class DeclarativeAggregate - extends AggregateFunction2 + extends AggregateFunction with Serializable with Unevaluable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala deleted file mode 100644 index 3dcf7915d77b3..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ /dev/null @@ -1,1073 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import com.clearspring.analytics.stream.cardinality.HyperLogLog - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData, TypeUtils} -import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.OpenHashSet - - -trait AggregateExpression extends Expression with Unevaluable - -trait AggregateExpression1 extends AggregateExpression { - - /** - * Aggregate expressions should not be foldable. - */ - override def foldable: Boolean = false - - /** - * Creates a new instance that can be used to compute this aggregate expression for a group - * of input rows/ - */ - def newInstance(): AggregateFunction1 -} - -/** - * Represents an aggregation that has been rewritten to be performed in two steps. - * - * @param finalEvaluation an aggregate expression that evaluates to same final result as the - * original aggregation. - * @param partialEvaluations A sequence of [[NamedExpression]]s that can be computed on partial - * data sets and are required to compute the `finalEvaluation`. - */ -case class SplitEvaluation( - finalEvaluation: Expression, - partialEvaluations: Seq[NamedExpression]) - -/** - * An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples. - * These partial evaluations can then be combined to compute the actual answer. - */ -trait PartialAggregate1 extends AggregateExpression1 { - - /** - * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation. - */ - def asPartial: SplitEvaluation -} - -/** - * A specific implementation of an aggregate function. Used to wrap a generic - * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result. - */ -abstract class AggregateFunction1 extends LeafExpression with Serializable { - - /** Base should return the generic aggregate expression that this function is computing */ - val base: AggregateExpression1 - - override def nullable: Boolean = base.nullable - override def dataType: DataType = base.dataType - - def update(input: InternalRow): Unit - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - throw new UnsupportedOperationException( - "AggregateFunction1 should not be used for generated aggregates") - } -} - -case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - - override def asPartial: SplitEvaluation = { - val partialMin = Alias(Min(child), "PartialMin")() - SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil) - } - - override def newInstance(): MinFunction = new MinFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(child.dataType, "function min") -} - -case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType) - val cmp = GreaterThan(currentMin, expr) - - override def update(input: InternalRow): Unit = { - if (currentMin.value == null) { - currentMin.value = expr.eval(input) - } else if (cmp.eval(input) == true) { - currentMin.value = expr.eval(input) - } - } - - override def eval(input: InternalRow): Any = currentMin.value -} - -case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - - override def asPartial: SplitEvaluation = { - val partialMax = Alias(Max(child), "PartialMax")() - SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil) - } - - override def newInstance(): MaxFunction = new MaxFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(child.dataType, "function max") -} - -case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType) - val cmp = LessThan(currentMax, expr) - - override def update(input: InternalRow): Unit = { - if (currentMax.value == null) { - currentMax.value = expr.eval(input) - } else if (cmp.eval(input) == true) { - currentMax.value = expr.eval(input) - } - } - - override def eval(input: InternalRow): Any = currentMax.value -} - -case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - - override def asPartial: SplitEvaluation = { - val partialCount = Alias(Count(child), "PartialCount")() - SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil) - } - - override def newInstance(): CountFunction = new CountFunction(child, this) -} - -case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - var count: Long = _ - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1L - } - } - - override def eval(input: InternalRow): Any = count -} - -case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 { - def this() = this(null) - - override def children: Seq[Expression] = expressions - - override def nullable: Boolean = false - override def dataType: DataType = LongType - override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})" - override def newInstance(): CountDistinctFunction = new CountDistinctFunction(expressions, this) - - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(expressions), "partialSets")() - SplitEvaluation( - CombineSetsAndCount(partialSet.toAttribute), - partialSet :: Nil) - } -} - -case class CountDistinctFunction( - @transient expr: Seq[Expression], - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - @transient - val distinctValue = new InterpretedProjection(expr) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) - } - } - - override def eval(input: InternalRow): Any = seen.size.toLong -} - -case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = expressions - override def nullable: Boolean = false - override def dataType: OpenHashSetUDT = new OpenHashSetUDT(expressions.head.dataType) - override def toString: String = s"AddToHashSet(${expressions.mkString(",")})" - override def newInstance(): CollectHashSetFunction = - new CollectHashSetFunction(expressions, this) -} - -case class CollectHashSetFunction( - @transient expr: Seq[Expression], - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - @transient - val distinctValue = new InterpretedProjection(expr) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) - } - } - - override def eval(input: InternalRow): Any = { - seen - } -} - -case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = inputSet :: Nil - override def nullable: Boolean = false - override def dataType: DataType = LongType - override def toString: String = s"CombineAndCount($inputSet)" - override def newInstance(): CombineSetsAndCountFunction = { - new CombineSetsAndCountFunction(inputSet, this) - } -} - -case class CombineSetsAndCountFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - override def update(input: InternalRow): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next) - } - } - - override def eval(input: InternalRow): Any = seen.size.toLong -} - -/** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */ -private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] { - - override def sqlType: DataType = BinaryType - - /** Since we are using HyperLogLog internally, usually it will not be called. */ - override def serialize(obj: Any): Array[Byte] = - obj.asInstanceOf[HyperLogLog].getBytes - - - /** Since we are using HyperLogLog internally, usually it will not be called. */ - override def deserialize(datum: Any): HyperLogLog = - HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]]) - - override def userClass: Class[HyperLogLog] = classOf[HyperLogLog] -} - -case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression1 { - - override def nullable: Boolean = false - override def dataType: DataType = HyperLogLogUDT - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance(): ApproxCountDistinctPartitionFunction = { - new ApproxCountDistinctPartitionFunction(child, this, relativeSD) - } -} - -case class ApproxCountDistinctPartitionFunction( - expr: Expression, - base: AggregateExpression1, - relativeSD: Double) - extends AggregateFunction1 { - def this() = this(null, null, 0) // Required for serialization. - - private val hyperLogLog = new HyperLogLog(relativeSD) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - hyperLogLog.offer(evaluatedExpr) - } - } - - override def eval(input: InternalRow): Any = hyperLogLog -} - -case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression1 { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance(): ApproxCountDistinctMergeFunction = { - new ApproxCountDistinctMergeFunction(child, this, relativeSD) - } -} - -case class ApproxCountDistinctMergeFunction( - expr: Expression, - base: AggregateExpression1, - relativeSD: Double) - extends AggregateFunction1 { - def this() = this(null, null, 0) // Required for serialization. - - private val hyperLogLog = new HyperLogLog(relativeSD) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog]) - } - - override def eval(input: InternalRow): Any = hyperLogLog.cardinality() -} - -case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) - extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - - override def asPartial: SplitEvaluation = { - val partialCount = - Alias(ApproxCountDistinctPartition(child, relativeSD), "PartialApproxCountDistinct")() - - SplitEvaluation( - ApproxCountDistinctMerge(partialCount.toAttribute, relativeSD), - partialCount :: Nil) - } - - override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) -} - -case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def prettyName: String = "avg" - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 4 digits after decimal point, like Hive - DecimalType.bounded(precision + 4, scale + 4) - case _ => - DoubleType - } - - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(precision, scale) => - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() - - // partialSum already increase the precision by 10 - val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType) - SplitEvaluation( - Cast(Divide(castedSum, castedCount), dataType), - partialCount :: partialSum :: Nil) - - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() - - val castedSum = Cast(Sum(partialSum.toAttribute), dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), dataType) - SplitEvaluation( - Divide(castedSum, castedCount), - partialCount :: partialSum :: Nil) - } - } - - override def newInstance(): AverageFunction = new AverageFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function average") -} - -case class AverageFunction(expr: Expression, base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - private val calcType = - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision + 10, scale) - case _ => - expr.dataType - } - - private val zero = Cast(Literal(0), calcType) - - private var count: Long = _ - private val sum = MutableLiteral(zero.eval(null), calcType) - - private def addFunction(value: Any) = Add(sum, - Cast(Literal.create(value, expr.dataType), calcType)) - - override def eval(input: InternalRow): Any = { - if (count == 0L) { - null - } else { - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - val dt = DecimalType.bounded(precision + 14, scale + 4) - Cast(Divide(Cast(sum, dt), Cast(Literal(count), dt)), dataType).eval(null) - case _ => - Divide( - Cast(sum, dataType), - Cast(Literal(count), dataType)).eval(null) - } - } - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1 - sum.update(addFunction(evaluatedExpr), input) - } - } -} - -case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 10 digits left of decimal point, like Hive - DecimalType.bounded(precision + 10, scale) - case _ => - child.dataType - } - - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(_, _) => - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - Cast(Sum(partialSum.toAttribute), dataType), - partialSum :: Nil) - - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - Sum(partialSum.toAttribute), - partialSum :: Nil) - } - } - - override def newInstance(): SumFunction = new SumFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function sum") -} - -case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - private val calcType = - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision + 10, scale) - case _ => - expr.dataType - } - - private val zero = Cast(Literal(0), calcType) - - private val sum = MutableLiteral(null, calcType) - - private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum)) - - override def update(input: InternalRow): Unit = { - sum.update(addFunction, input) - } - - override def eval(input: InternalRow): Any = { - expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(sum, dataType).eval(null) - case _ => sum.eval(null) - } - } -} - -case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 { - - def this() = this(null) - override def nullable: Boolean = true - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 10 digits left of decimal point, like Hive - DecimalType.bounded(precision + 10, scale) - case _ => - child.dataType - } - override def toString: String = s"sum(distinct $child)" - override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) - - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() - SplitEvaluation( - CombineSetsAndSum(partialSet.toAttribute, this), - partialSet :: Nil) - } - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct") -} - -case class SumDistinctFunction(expr: Expression, base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - private val seen = new scala.collection.mutable.HashSet[Any]() - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - seen += evaluatedExpr - } - } - - override def eval(input: InternalRow): Any = { - if (seen.size == 0) { - null - } else { - Cast(Literal( - seen.reduceLeft( - dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - dataType).eval(null) - } - } -} - -case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 { - def this() = this(null, null) - - override def children: Seq[Expression] = inputSet :: Nil - override def nullable: Boolean = true - override def dataType: DataType = base.dataType - override def toString: String = s"CombineAndSum($inputSet)" - override def newInstance(): CombineSetsAndSumFunction = { - new CombineSetsAndSumFunction(inputSet, this) - } -} - -case class CombineSetsAndSumFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - override def update(input: InternalRow): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next()) - } - } - - override def eval(input: InternalRow): Any = { - val casted = seen.asInstanceOf[OpenHashSet[InternalRow]] - if (casted.size == 0) { - null - } else { - Cast(Literal( - casted.iterator.map(f => f.get(0, null)).reduceLeft( - base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - base.dataType).eval(null) - } - } -} - -case class First( - child: Expression, - ignoreNullsExpr: Expression) - extends UnaryExpression with PartialAggregate1 { - - def this(child: Expression) = this(child, Literal.create(false, BooleanType)) - - private val ignoreNulls: Boolean = ignoreNullsExpr match { - case Literal(b: Boolean, BooleanType) => b - case _ => - throw new AnalysisException("The second argument of First should be a boolean literal.") - } - - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"first(${child}${if (ignoreNulls) " ignore nulls"})" - - override def asPartial: SplitEvaluation = { - val partialFirst = Alias(First(child, ignoreNulls), "PartialFirst")() - SplitEvaluation( - First(partialFirst.toAttribute, ignoreNulls), - partialFirst :: Nil) - } - override def newInstance(): FirstFunction = new FirstFunction(child, ignoreNulls, this) -} - -object First { - def apply(child: Expression): First = First(child, ignoreNulls = false) - - def apply(child: Expression, ignoreNulls: Boolean): First = - First(child, Literal.create(ignoreNulls, BooleanType)) -} - -case class FirstFunction( - expr: Expression, - ignoreNulls: Boolean, - base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization. - - private[this] var result: Any = null - - private[this] var valueSet: Boolean = false - - override def update(input: InternalRow): Unit = { - if (!valueSet) { - val value = expr.eval(input) - // When we have not set the result, we will set the result if we respect nulls - // (i.e. ignoreNulls is false), or we ignore nulls and the evaluated value is not null. - if (!ignoreNulls || (ignoreNulls && value != null)) { - result = value - valueSet = true - } - } - } - - override def eval(input: InternalRow): Any = result -} - -case class Last( - child: Expression, - ignoreNullsExpr: Expression) - extends UnaryExpression with PartialAggregate1 { - - def this(child: Expression) = this(child, Literal.create(false, BooleanType)) - - private val ignoreNulls: Boolean = ignoreNullsExpr match { - case Literal(b: Boolean, BooleanType) => b - case _ => - throw new AnalysisException("The second argument of First should be a boolean literal.") - } - - override def references: AttributeSet = child.references - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}" - - override def asPartial: SplitEvaluation = { - val partialLast = Alias(Last(child, ignoreNulls), "PartialLast")() - SplitEvaluation( - Last(partialLast.toAttribute, ignoreNulls), - partialLast :: Nil) - } - override def newInstance(): LastFunction = new LastFunction(child, ignoreNulls, this) -} - -object Last { - def apply(child: Expression): Last = Last(child, ignoreNulls = false) - - def apply(child: Expression, ignoreNulls: Boolean): Last = - Last(child, Literal.create(ignoreNulls, BooleanType)) -} - -case class LastFunction( - expr: Expression, - ignoreNulls: Boolean, - base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization. - - var result: Any = null - - override def update(input: InternalRow): Unit = { - val value = expr.eval(input) - if (!ignoreNulls || (ignoreNulls && value != null)) { - result = value - } - } - - override def eval(input: InternalRow): Any = { - result - } -} - -/** - * Calculate Pearson Correlation Coefficient for the given columns. - * Only support AggregateExpression2. - * - */ -case class Corr(left: Expression, right: Expression) - extends BinaryExpression with AggregateExpression1 with ImplicitCastInputTypes { - override def nullable: Boolean = false - override def dataType: DoubleType.type = DoubleType - override def toString: String = s"corr($left, $right)" - override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException( - "Corr only supports the new AggregateExpression2 and can only be used " + - "when spark.sql.useAggregate2 = true") - } -} - -// Compute standard deviation based on online algorithm specified here: -// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance -abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 { - override def nullable: Boolean = true - override def dataType: DataType = DoubleType - - def isSample: Boolean - - override def asPartial: SplitEvaluation = { - val partialStd = Alias(ComputePartialStd(child), "PartialStddev")() - SplitEvaluation(MergePartialStd(partialStd.toAttribute, isSample), partialStd :: Nil) - } - - override def newInstance(): StddevFunction = new StddevFunction(child, this, isSample) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function stddev") - -} - -// Compute the population standard deviation of a column -case class StddevPop(child: Expression) extends StddevAgg1(child) { - - override def toString: String = s"stddev_pop($child)" - override def isSample: Boolean = false -} - -// Compute the sample standard deviation of a column -case class StddevSamp(child: Expression) extends StddevAgg1(child) { - - override def toString: String = s"stddev_samp($child)" - override def isSample: Boolean = true -} - -case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = false - override def dataType: DataType = ArrayType(DoubleType) - override def toString: String = s"computePartialStddev($child)" - override def newInstance(): ComputePartialStdFunction = - new ComputePartialStdFunction(child, this) -} - -case class ComputePartialStdFunction ( - expr: Expression, - base: AggregateExpression1 - ) extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization - - private val computeType = DoubleType - private val zero = Cast(Literal(0), computeType) - private var partialCount: Long = 0L - - // the mean of data processed so far - private val partialAvg: MutableLiteral = MutableLiteral(zero.eval(null), computeType) - - // update average based on this formula: - // avg = avg + (value - avg)/count - private def avgAddFunction (value: Literal): Expression = { - val delta = Subtract(Cast(value, computeType), partialAvg) - Add(partialAvg, Divide(delta, Cast(Literal(partialCount), computeType))) - } - - // the sum of squares of difference from mean - private val partialMk: MutableLiteral = MutableLiteral(zero.eval(null), computeType) - - // update sum of square of difference from mean based on following formula: - // Mk = Mk + (value - preAvg) * (value - updatedAvg) - private def mkAddFunction(value: Literal, prePartialAvg: MutableLiteral): Expression = { - val delta1 = Subtract(Cast(value, computeType), prePartialAvg) - val delta2 = Subtract(Cast(value, computeType), partialAvg) - Add(partialMk, Multiply(delta1, delta2)) - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - val exprValue = Literal.create(evaluatedExpr, expr.dataType) - val prePartialAvg = partialAvg.copy() - partialCount += 1 - partialAvg.update(avgAddFunction(exprValue), input) - partialMk.update(mkAddFunction(exprValue, prePartialAvg), input) - } - } - - override def eval(input: InternalRow): Any = { - new GenericArrayData(Array(Cast(Literal(partialCount), computeType).eval(null), - partialAvg.eval(null), - partialMk.eval(null))) - } -} - -case class MergePartialStd( - child: Expression, - isSample: Boolean -) extends UnaryExpression with AggregateExpression1 { - def this() = this(null, false) // required for serialization - - override def children: Seq[Expression] = child:: Nil - override def nullable: Boolean = false - override def dataType: DataType = DoubleType - override def toString: String = s"MergePartialStd($child)" - override def newInstance(): MergePartialStdFunction = { - new MergePartialStdFunction(child, this, isSample) - } -} - -case class MergePartialStdFunction( - expr: Expression, - base: AggregateExpression1, - isSample: Boolean -) extends AggregateFunction1 { - def this() = this (null, null, false) // Required for serialization - - private val computeType = DoubleType - private val zero = Cast(Literal(0), computeType) - private val combineCount = MutableLiteral(zero.eval(null), computeType) - private val combineAvg = MutableLiteral(zero.eval(null), computeType) - private val combineMk = MutableLiteral(zero.eval(null), computeType) - - private def avgUpdateFunction(preCount: Expression, - partialCount: Expression, - partialAvg: Expression): Expression = { - Divide(Add(Multiply(combineAvg, preCount), - Multiply(partialAvg, partialCount)), - Add(preCount, partialCount)) - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input).asInstanceOf[ArrayData] - - if (evaluatedExpr != null) { - val exprValue = evaluatedExpr.toArray(computeType) - val (partialCount, partialAvg, partialMk) = - (Literal.create(exprValue(0), computeType), - Literal.create(exprValue(1), computeType), - Literal.create(exprValue(2), computeType)) - - if (Cast(partialCount, LongType).eval(null).asInstanceOf[Long] > 0) { - val preCount = combineCount.copy() - combineCount.update(Add(combineCount, partialCount), input) - - val preAvg = combineAvg.copy() - val avgDelta = Subtract(partialAvg, preAvg) - val mkDelta = Multiply(Multiply(avgDelta, avgDelta), - Divide(Multiply(preCount, partialCount), - combineCount)) - - // update average based on following formula - // (combineAvg * preCount + partialAvg * partialCount) / (preCount + partialCount) - combineAvg.update(avgUpdateFunction(preCount, partialCount, partialAvg), input) - - // update sum of square differences from mean based on following formula - // (combineMk + partialMk + (avgDelta * avgDelta) * (preCount * partialCount/combineCount) - combineMk.update(Add(combineMk, Add(partialMk, mkDelta)), input) - } - } - } - - override def eval(input: InternalRow): Any = { - val count: Long = Cast(combineCount, LongType).eval(null).asInstanceOf[Long] - - if (count == 0) null - else if (count < 2) zero.eval(null) - else { - // when total count > 2 - // stddev_samp = sqrt (combineMk/(combineCount -1)) - // stddev_pop = sqrt (combineMk/combineCount) - val varCol = { - if (isSample) { - Divide(combineMk, Cast(Literal(count - 1), computeType)) - } - else { - Divide(combineMk, Cast(Literal(count), computeType)) - } - } - Sqrt(varCol).eval(null) - } - } -} - -case class StddevFunction( - expr: Expression, - base: AggregateExpression1, - isSample: Boolean -) extends AggregateFunction1 { - - def this() = this(null, null, false) // Required for serialization - - private val computeType = DoubleType - private var curCount: Long = 0L - private val zero = Cast(Literal(0), computeType) - private val curAvg = MutableLiteral(zero.eval(null), computeType) - private val curMk = MutableLiteral(zero.eval(null), computeType) - - private def curAvgAddFunction(value: Literal): Expression = { - val delta = Subtract(Cast(value, computeType), curAvg) - Add(curAvg, Divide(delta, Cast(Literal(curCount), computeType))) - } - private def curMkAddFunction(value: Literal, preAvg: MutableLiteral): Expression = { - val delta1 = Subtract(Cast(value, computeType), preAvg) - val delta2 = Subtract(Cast(value, computeType), curAvg) - Add(curMk, Multiply(delta1, delta2)) - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - val preAvg: MutableLiteral = curAvg.copy() - val exprValue = Literal.create(evaluatedExpr, expr.dataType) - curCount += 1L - curAvg.update(curAvgAddFunction(exprValue), input) - curMk.update(curMkAddFunction(exprValue, preAvg), input) - } - } - - override def eval(input: InternalRow): Any = { - if (curCount == 0) null - else if (curCount < 2) zero.eval(null) - else { - // when total count > 2, - // stddev_samp = sqrt(curMk/(curCount - 1)) - // stddev_pop = sqrt(curMk/curCount) - val varCol = { - if (isSample) { - Divide(curMk, Cast(Literal(curCount - 1), computeType)) - } - else { - Divide(curMk, Cast(Literal(curCount), computeType)) - } - } - Sqrt(varCol).eval(null) - } - } -} - -// placeholder -case class Kurtosis(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "kurtosis" -} - -// placeholder -case class Skewness(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "skewness" -} - -// placeholder -case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "var_pop" -} - -// placeholder -case class VarianceSamp(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "var_samp" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f0f7a6cf0cc4d..1718cfbd35332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -33,10 +33,6 @@ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ -// These classes are here to avoid issues with serialization and integration with quasiquotes. -class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int] -class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] - /** * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input. * @@ -92,6 +88,33 @@ class CodeGenContext { addedFunctions += ((funcName, funcCode)) } + /** + * Holds expressions that are equivalent. Used to perform subexpression elimination + * during codegen. + * + * For expressions that appear more than once, generate additional code to prevent + * recomputing the value. + * + * For example, consider two exprsesion generated from this SQL statement: + * SELECT (col1 + col2), (col1 + col2) / col3. + * + * equivalentExpressions will match the tree containing `col1 + col2` and it will only + * be evaluated once. + */ + val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions + + // State used for subexpression elimination. + case class SubExprEliminationState( + isLoaded: String, + code: GeneratedExpressionCode, + fnName: String) + + // Foreach expression that is participating in subexpression elimination, the state to use. + val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] + + // The collection of isLoaded variables that need to be reset on each row. + val subExprIsLoadedVariables = mutable.ArrayBuffer.empty[String] + final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -178,8 +201,6 @@ class CodeGenContext { case _: StructType => "InternalRow" case _: ArrayType => "ArrayData" case _: MapType => "MapData" - case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName - case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case udt: UserDefinedType[_] => javaType(udt.sqlType) case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" case ObjectType(cls) => cls.getName @@ -246,6 +267,49 @@ class CodeGenContext { case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" case NullType => "0" + case array: ArrayType => + val elementType = array.elementType + val elementA = freshName("elementA") + val isNullA = freshName("isNullA") + val elementB = freshName("elementB") + val isNullB = freshName("isNullB") + val compareFunc = freshName("compareArray") + val minLength = freshName("minLength") + val funcCode: String = + s""" + public int $compareFunc(ArrayData a, ArrayData b) { + int lengthA = a.numElements(); + int lengthB = b.numElements(); + int $minLength = (lengthA > lengthB) ? lengthB : lengthA; + for (int i = 0; i < $minLength; i++) { + boolean $isNullA = a.isNullAt(i); + boolean $isNullB = b.isNullAt(i); + if ($isNullA && $isNullB) { + // Nothing + } else if ($isNullA) { + return -1; + } else if ($isNullB) { + return 1; + } else { + ${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")}; + ${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")}; + int comp = ${genComp(elementType, elementA, elementB)}; + if (comp != 0) { + return comp; + } + } + } + + if (lengthA < lengthB) { + return -1; + } else if (lengthA > lengthB) { + return 1; + } + return 0; + } + """ + addNewFunction(compareFunc, funcCode) + s"this.$compareFunc($c1, $c2)" case schema: StructType => val comparisons = GenerateOrdering.genComparisons(this, schema) val compareFunc = freshName("compareStruct") @@ -317,6 +381,87 @@ class CodeGenContext { functions.map(name => s"$name($row);").mkString("\n") } } + + /** + * Checks and sets up the state and codegen for subexpression elimination. This finds the + * common subexpresses, generates the functions that evaluate those expressions and populates + * the mapping of common subexpressions to the generated functions. + */ + private def subexpressionElimination(expressions: Seq[Expression]) = { + // Add each expression tree and compute the common subexpressions. + expressions.foreach(equivalentExpressions.addExprTree(_)) + + // Get all the exprs that appear at least twice and set up the state for subexpression + // elimination. + val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) + commonExprs.foreach(e => { + val expr = e.head + val isLoaded = freshName("isLoaded") + val isNull = freshName("isNull") + val value = freshName("value") + val fnName = freshName("evalExpr") + + // Generate the code for this expression tree and wrap it in a function. + val code = expr.gen(this) + val fn = + s""" + |private void $fnName(InternalRow ${INPUT_ROW}) { + | if (!$isLoaded) { + | ${code.code.trim} + | $isLoaded = true; + | $isNull = ${code.isNull}; + | $value = ${code.value}; + | } + |} + """.stripMargin + code.code = fn + code.isNull = isNull + code.value = value + + addNewFunction(fnName, fn) + + // Add a state and a mapping of the common subexpressions that are associate with this + // state. Adding this expression to subExprEliminationExprMap means it will call `fn` + // when it is code generated. This decision should be a cost based one. + // + // The cost of doing subexpression elimination is: + // 1. Extra function call, although this is probably *good* as the JIT can decide to + // inline or not. + // 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly + // very often. The reason it is not loaded is because of a prior branch. + // 3. Extra store into isLoaded. + // The benefit doing subexpression elimination is: + // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 + // above. + // 2. Less code. + // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with + // at least two nodes) as the cost of doing it is expected to be low. + + // Maintain the loaded value and isNull as member variables. This is necessary if the codegen + // function is split across multiple functions. + // TODO: maintaining this as a local variable probably allows the compiler to do better + // optimizations. + addMutableState("boolean", isLoaded, s"$isLoaded = false;") + addMutableState("boolean", isNull, s"$isNull = false;") + addMutableState(javaType(expr.dataType), value, + s"$value = ${defaultValue(expr.dataType)};") + subExprIsLoadedVariables += isLoaded + + val state = SubExprEliminationState(isLoaded, code, fnName) + e.foreach(subExprEliminationExprs.put(_, state)) + }) + } + + /** + * Generates code for expressions. If doSubexpressionElimination is true, subexpression + * elimination will be performed. Subexpression elimination assumes that the code will for each + * expression will be combined in the `expressions` order. + */ + def generateExpressions(expressions: Seq[Expression], + doSubexpressionElimination: Boolean = false): Seq[GeneratedExpressionCode] = { + if (doSubexpressionElimination) subexpressionElimination(expressions) + expressions.map(e => e.gen(this)) + } } /** @@ -349,7 +494,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } protected def declareAddedFunctions(ctx: CodeGenContext): String = { - ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") + ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 2136f82ba4752..4c17d02a23725 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -39,7 +39,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true - case dt: OpenHashSetUDT => false // it's not a standard UDT case udt: UserDefinedType[_] => canSupport(udt.sqlType) case _ => false } @@ -139,9 +138,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" ${input.code} if (${input.isNull}) { - $setNull + ${setNull.trim} } else { - $writeField + ${writeField.trim} } """ } @@ -149,7 +148,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" $rowWriter.initialize($bufferHolder, ${inputs.length}); ${ctx.splitExpressions(row, writeFields)} - """ + """.trim } // TODO: if the nullability of array element is correct, we can use it to save null check. @@ -275,8 +274,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """ } - def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = { - val exprEvals = expressions.map(e => e.gen(ctx)) + def createCode( + ctx: CodeGenContext, + expressions: Seq[Expression], + useSubexprElimination: Boolean = false): GeneratedExpressionCode = { + val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) val exprTypes = expressions.map(_.dataType) val result = ctx.freshName("result") @@ -285,10 +287,15 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") + // Reset the isLoaded flag for each row. + val subexprReset = ctx.subExprIsLoadedVariables.map { v => s"${v} = false;" }.mkString("\n") + val code = s""" $bufferHolder.reset(); + $subexprReset ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)} + $result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize()); """ GeneratedExpressionCode(code, "false", result) @@ -300,10 +307,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) + def generate( + expressions: Seq[Expression], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + create(canonicalize(expressions), subexpressionEliminationEnabled) + } + protected def create(expressions: Seq[Expression]): UnsafeProjection = { - val ctx = newCodeGenContext() + create(expressions, subexpressionEliminationEnabled = false) + } - val eval = createCode(ctx, expressions) + private def create( + expressions: Seq[Expression], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + val ctx = newCodeGenContext() + val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) val code = s""" public Object generate($exprType[] exprs) { @@ -315,6 +333,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private $exprType[] expressions; ${declareMutableStates(ctx)} + ${declareAddedFunctions(ctx)} public SpecificUnsafeProjection($exprType[] expressions) { @@ -328,7 +347,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - ${eval.code} + ${eval.code.trim} return ${eval.value}; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2cf19b939f734..741ad1f3efd8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -68,6 +68,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) private lazy val lt: Comparator[Any] = { val ordering = base.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } @@ -90,6 +91,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) private lazy val gt: Comparator[Any] = { val ordering = base.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 41cd0a104a1f5..f871b737fff3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -97,11 +97,16 @@ object ExtractValue { * Returns the value of fields in the Struct `child`. * * No need to do type checking since it is handled by [[ExtractValue]]. + * TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]]. */ case class GetStructField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression { - override def dataType: DataType = field.dataType + override def dataType: DataType = child.dataType match { + case s: StructType => s(ordinal).dataType + // This is a hack to avoid breaking existing code until we remove the need for the struct field + case _ => field.dataType + } override def nullable: Boolean = child.nullable || field.nullable override def toString: String = s"$child.${field.name}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index d532629984bec..0d4af43978ea1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types.{NullType, BooleanType, DataType} +import org.apache.spark.sql.types._ case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) @@ -419,3 +419,31 @@ case class Greatest(children: Seq[Expression]) extends Expression { """ } } + +/** Operator that drops a row when it contains any nulls. */ +case class DropAnyNull(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(StructType) + + protected override def nullSafeEval(input: Any): InternalRow = { + val row = input.asInstanceOf[InternalRow] + if (row.anyNull) { + null + } else { + row + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, eval => { + s""" + if ($eval.anyNull()) { + ${ev.isNull} = true; + } else { + ${ev.value} = $eval; + } + """ + }) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 13cc6bb6f27b8..03c39f8404e78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -299,7 +299,20 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx } /** - * Converts time string with given pattern + * Converts time string with given pattern. + * Deterministic version of [[UnixTimestamp]], must have at least one parameter. + */ +case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime { + override def left: Expression = timeExp + override def right: Expression = format + + def this(time: Expression) = { + this(time, Literal("yyyy-MM-dd HH:mm:ss")) + } +} + +/** + * Converts time string with given pattern. * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) * to Unix time stamp (in seconds), returns null if fail. * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. @@ -308,9 +321,7 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx * If the first parameter is a Date or Timestamp instead of String, we will ignore the * second parameter. */ -case class UnixTimestamp(timeExp: Expression, format: Expression) - extends BinaryExpression with ExpectsInputTypes { - +case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime { override def left: Expression = timeExp override def right: Expression = format @@ -321,6 +332,9 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) def this() = { this(CurrentTimestamp()) } +} + +abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, DateType, TimestampType), StringType) @@ -347,7 +361,7 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) null } case StringType => - val f = format.eval(input) + val f = right.eval(input) if (f == null) { null } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 8c9853e628d2c..8cd73236a7876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -314,7 +314,7 @@ case class GetJsonObject(json: Expression, path: Expression) } case class JsonTuple(children: Seq[Expression]) - extends Expression with CodegenFallback { + extends Generator with CodegenFallback { import SharedFactory._ @@ -324,8 +324,8 @@ case class JsonTuple(children: Seq[Expression]) } // if processing fails this shared value will be returned - @transient private lazy val nullRow: InternalRow = - new GenericInternalRow(Array.ofDim[Any](fieldExpressions.length)) + @transient private lazy val nullRow: Seq[InternalRow] = + new GenericInternalRow(Array.ofDim[Any](fieldExpressions.length)) :: Nil // the json body is the first child @transient private lazy val jsonExpr: Expression = children.head @@ -344,15 +344,8 @@ case class JsonTuple(children: Seq[Expression]) // and count the number of foldable fields, we'll use this later to optimize evaluation @transient private lazy val constantFields: Int = foldableFieldNames.count(_ != null) - override lazy val dataType: StructType = { - val fields = fieldExpressions.zipWithIndex.map { - case (_, idx) => StructField( - name = s"c$idx", // mirroring GenericUDTFJSONTuple.initialize - dataType = StringType, - nullable = true) - } - - StructType(fields) + override def elementTypes: Seq[(DataType, Boolean, String)] = fieldExpressions.zipWithIndex.map { + case (_, idx) => (StringType, true, s"c$idx") } override def prettyName: String = "json_tuple" @@ -367,7 +360,7 @@ case class JsonTuple(children: Seq[Expression]) } } - override def eval(input: InternalRow): InternalRow = { + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { val json = jsonExpr.eval(input).asInstanceOf[UTF8String] if (json == null) { return nullRow @@ -383,7 +376,7 @@ case class JsonTuple(children: Seq[Expression]) } } - private def parseRow(parser: JsonParser, input: InternalRow): InternalRow = { + private def parseRow(parser: JsonParser, input: InternalRow): Seq[InternalRow] = { // only objects are supported if (parser.nextToken() != JsonToken.START_OBJECT) { return nullRow @@ -433,7 +426,7 @@ case class JsonTuple(children: Seq[Expression]) parser.skipChildren() } - new GenericInternalRow(row) + new GenericInternalRow(row) :: Nil } private def copyCurrentStructure(generator: JsonGenerator, parser: JsonParser): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 9ab5c299d0f55..f80bcfcb0b0bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -203,6 +203,10 @@ case class AttributeReference( case _ => false } + override def semanticHash(): Int = { + this.exprId.hashCode() + } + override def hashCode: Int = { // See http://stackoverflow.com/questions/113511/hash-code-implementation var h = 17 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 4f58464221b4b..5cd19de68391c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -113,7 +113,7 @@ case class Invoke( arguments: Seq[Expression] = Nil) extends Expression { override def nullable: Boolean = true - override def children: Seq[Expression] = targetObject :: Nil + override def children: Seq[Expression] = arguments.+:(targetObject) override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") @@ -343,33 +343,35 @@ case class MapObjects( private lazy val loopAttribute = AttributeReference("loopVar", elementType)() private lazy val completeFunction = function(loopAttribute) + private def itemAccessorMethod(dataType: DataType): String => String = dataType match { + case IntegerType => (i: String) => s".getInt($i)" + case LongType => (i: String) => s".getLong($i)" + case FloatType => (i: String) => s".getFloat($i)" + case DoubleType => (i: String) => s".getDouble($i)" + case ByteType => (i: String) => s".getByte($i)" + case ShortType => (i: String) => s".getShort($i)" + case BooleanType => (i: String) => s".getBoolean($i)" + case StringType => (i: String) => s".getUTF8String($i)" + case s: StructType => (i: String) => s".getStruct($i, ${s.size})" + case a: ArrayType => (i: String) => s".getArray($i)" + case _: MapType => (i: String) => s".getMap($i)" + case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType) + } + private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match { case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => (".size()", (i: String) => s".apply($i)", false) case ObjectType(cls) if cls.isArray => (".length", (i: String) => s"[$i]", false) - case ArrayType(s: StructType, _) => - (".numElements()", (i: String) => s".getStruct($i, ${s.size})", false) - case ArrayType(a: ArrayType, _) => - (".numElements()", (i: String) => s".getArray($i)", true) - case ArrayType(IntegerType, _) => - (".numElements()", (i: String) => s".getInt($i)", true) - case ArrayType(LongType, _) => - (".numElements()", (i: String) => s".getLong($i)", true) - case ArrayType(FloatType, _) => - (".numElements()", (i: String) => s".getFloat($i)", true) - case ArrayType(DoubleType, _) => - (".numElements()", (i: String) => s".getDouble($i)", true) - case ArrayType(ByteType, _) => - (".numElements()", (i: String) => s".getByte($i)", true) - case ArrayType(ShortType, _) => - (".numElements()", (i: String) => s".getShort($i)", true) - case ArrayType(BooleanType, _) => - (".numElements()", (i: String) => s".getBoolean($i)", true) - case ArrayType(StringType, _) => - (".numElements()", (i: String) => s".getUTF8String($i)", false) - case ArrayType(_: MapType, _) => - (".numElements()", (i: String) => s".getMap($i)", false) + case ArrayType(t, _) => + val (sqlType, primitiveElement) = t match { + case m: MapType => (m, false) + case s: StructType => (s, false) + case s: StringType => (s, false) + case udt: UserDefinedType[_] => (udt.sqlType, false) + case o => (o, true) + } + (".numElements()", itemAccessorMethod(sqlType), primitiveElement) } override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index 6407c73bc97d9..6112259fed619 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -48,6 +48,10 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow dt.ordering.asInstanceOf[Ordering[Any]].compare(left, right) case dt: AtomicType if order.direction == Descending => dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) + case a: ArrayType if order.direction == Ascending => + a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) + case a: ArrayType if order.direction == Descending => + a.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) case s: StructType if order.direction == Ascending => s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) case s: StructType if order.direction == Descending => @@ -86,6 +90,8 @@ object RowOrdering { case NullType => true case dt: AtomicType => true case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType)) + case array: ArrayType => isOrderable(array.elementType) + case udt: UserDefinedType[_] => isOrderable(udt.sqlType) case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala deleted file mode 100644 index d124d29d534b8..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.OpenHashSet - -/** The data type for expressions returning an OpenHashSet as the result. */ -private[sql] class OpenHashSetUDT( - val elementType: DataType) extends UserDefinedType[OpenHashSet[Any]] { - - override def sqlType: DataType = ArrayType(elementType) - - /** Since we are using OpenHashSet internally, usually it will not be called. */ - override def serialize(obj: Any): Seq[Any] = { - obj.asInstanceOf[OpenHashSet[Any]].iterator.toSeq - } - - /** Since we are using OpenHashSet internally, usually it will not be called. */ - override def deserialize(datum: Any): OpenHashSet[Any] = { - val iterator = datum.asInstanceOf[Seq[Any]].iterator - val set = new OpenHashSet[Any] - while(iterator.hasNext) { - set.add(iterator.next()) - } - - set - } - - override def userClass: Class[OpenHashSet[Any]] = classOf[OpenHashSet[Any]] - - private[spark] override def asNullable: OpenHashSetUDT = this -} - -/** - * Creates a new set of the specified type - */ -case class NewSet(elementType: DataType) extends LeafExpression with CodegenFallback { - - override def nullable: Boolean = false - - override def dataType: OpenHashSetUDT = new OpenHashSetUDT(elementType) - - override def eval(input: InternalRow): Any = { - new OpenHashSet[Any]() - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - elementType match { - case IntegerType | LongType => - ev.isNull = "false" - s""" - ${ctx.javaType(dataType)} ${ev.value} = new ${ctx.javaType(dataType)}(); - """ - case _ => super.genCode(ctx, ev) - } - } - - override def toString: String = s"new Set($dataType)" -} - -/** - * Adds an item to a set. - * For performance, this expression mutates its input during evaluation. - * Note: this expression is internal and created only by the GeneratedAggregate, - * we don't need to do type check for it. - */ -case class AddItemToSet(item: Expression, set: Expression) - extends Expression with CodegenFallback { - - override def children: Seq[Expression] = item :: set :: Nil - - override def nullable: Boolean = set.nullable - - override def dataType: DataType = set.dataType - - override def eval(input: InternalRow): Any = { - val itemEval = item.eval(input) - val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]] - - if (itemEval != null) { - if (setEval != null) { - setEval.add(itemEval) - setEval - } else { - null - } - } else { - setEval - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType - elementType match { - case IntegerType | LongType => - val itemEval = item.gen(ctx) - val setEval = set.gen(ctx) - val htype = ctx.javaType(dataType) - - ev.isNull = "false" - ev.value = setEval.value - itemEval.code + setEval.code + s""" - if (!${itemEval.isNull} && !${setEval.isNull}) { - (($htype)${setEval.value}).add(${itemEval.value}); - } - """ - case _ => super.genCode(ctx, ev) - } - } - - override def toString: String = s"$set += $item" -} - -/** - * Combines the elements of two sets. - * For performance, this expression mutates its left input set during evaluation. - * Note: this expression is internal and created only by the GeneratedAggregate, - * we don't need to do type check for it. - */ -case class CombineSets(left: Expression, right: Expression) - extends BinaryExpression with CodegenFallback { - - override def nullable: Boolean = left.nullable - override def dataType: DataType = left.dataType - - override def eval(input: InternalRow): Any = { - val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] - if(leftEval != null) { - val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]] - if (rightEval != null) { - val iterator = rightEval.iterator - while(iterator.hasNext) { - val rightValue = iterator.next() - leftEval.add(rightValue) - } - } - leftEval - } else { - null - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType - elementType match { - case IntegerType | LongType => - val leftEval = left.gen(ctx) - val rightEval = right.gen(ctx) - val htype = ctx.javaType(dataType) - - ev.isNull = leftEval.isNull - ev.value = leftEval.value - leftEval.code + rightEval.code + s""" - if (!${leftEval.isNull} && !${rightEval.isNull}) { - ${leftEval.value}.union((${htype})${rightEval.value}); - } - """ - case _ => super.genCode(ctx, ev) - } - } -} - -/** - * Returns the number of elements in the input set. - * Note: this expression is internal and created only by the GeneratedAggregate, - * we don't need to do type check for it. - */ -case class CountSet(child: Expression) extends UnaryExpression with CodegenFallback { - - override def dataType: DataType = LongType - - protected override def nullSafeEval(input: Any): Any = - input.asInstanceOf[OpenHashSet[Any]].size.toLong - - override def toString: String = s"$child.count()" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d222dfa33ad8a..f4dba67f13b54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter import org.apache.spark.sql.catalyst.plans.LeftOuter @@ -201,8 +202,8 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case a @ Aggregate(_, _, e @ Expand(_, _, child)) - if (child.outputSet -- AttributeSet(e.output) -- a.references).nonEmpty => - a.copy(child = e.copy(child = prunedChild(child, AttributeSet(e.output) ++ a.references))) + if (child.outputSet -- e.references -- a.references).nonEmpty => + a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references))) // Eliminate attributes that are not needed to calculate the specified aggregates. case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => @@ -363,7 +364,8 @@ object LikeSimplification extends Rule[LogicalPlan] { object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType) + case e @ AggregateExpression(Count(Literal(null, _)), _, _) => + Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType) @@ -375,7 +377,9 @@ object NullPropagation extends Rule[LogicalPlan] { Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) - case e @ Count(expr) if !expr.nullable => Count(Literal(1)) + case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable => + // This rule should be only triggered when isDistinct field is false. + AggregateExpression(Count(Literal(1)), mode, isDistinct = false) // For Coalesce, remove null literals. case e @ Coalesce(children) => @@ -857,12 +861,15 @@ object DecimalAggregates extends Rule[LogicalPlan] { private val MAX_DOUBLE_DIGITS = 15 def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => - MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale) + case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct) + if prec + 10 <= MAX_LONG_DIGITS => + MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale) - case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => + case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct) + if prec + 4 <= MAX_DOUBLE_DIGITS => + val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct) Cast( - Divide(Average(UnscaledValue(e)), Literal.create(math.pow(10.0, scale), DoubleType)), + Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 3b975b904a332..6f4f11406d7c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -84,80 +84,6 @@ object PhysicalOperation extends PredicateHelper { } } -/** - * Matches a logical aggregation that can be performed on distributed data in two steps. The first - * operates on the data in each partition performing partial aggregation for each group. The second - * occurs after the shuffle and completes the aggregation. - * - * This pattern will only match if all aggregate expressions can be computed partially and will - * return the rewritten aggregation expressions for both phases. - * - * The returned values for this match are as follows: - * - Grouping attributes for the final aggregation. - * - Aggregates for the final aggregation. - * - Grouping expressions for the partial aggregation. - * - Partial aggregate expressions. - * - Input to the aggregation. - */ -object PartialAggregation { - type ReturnType = - (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) - - def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => - // Collect all aggregate expressions. - val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a}) - // Collect all aggregate expressions that can be computed partially. - val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p}) - - // Only do partial aggregation if supported by all aggregate expressions. - if (allAggregates.size == partialAggregates.size) { - // Create a map of expressions to their partial evaluations for all aggregate expressions. - val partialEvaluations: Map[TreeNodeRef, SplitEvaluation] = - partialAggregates.map(a => (new TreeNodeRef(a), a.asPartial)).toMap - - // We need to pass all grouping expressions though so the grouping can happen a second - // time. However some of them might be unnamed so we alias them allowing them to be - // referenced in the second aggregation. - val namedGroupingExpressions: Seq[(Expression, NamedExpression)] = - groupingExpressions.map { - case n: NamedExpression => (n, n) - case other => (other, Alias(other, "PartialGroup")()) - } - - // Replace aggregations with a new expression that computes the result from the already - // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown { - case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => - partialEvaluations(new TreeNodeRef(e)).finalEvaluation - - case e: Expression => - namedGroupingExpressions.collectFirst { - case (expr, ne) if expr semanticEquals e => ne.toAttribute - }.getOrElse(e) - }).asInstanceOf[Seq[NamedExpression]] - - val partialComputation = namedGroupingExpressions.map(_._2) ++ - partialEvaluations.values.flatMap(_.partialEvaluations) - - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) - - Some( - (namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - partialComputation, - child)) - } else { - None - } - case _ => None - } -} - - /** * A pattern that finds joins with equality conditions that can be evaluated using equi-join. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0ec9f08571082..b9db7838db08a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -137,13 +137,17 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** Returns all of the expressions present in this query plan operator. */ def expressions: Seq[Expression] = { + // Recursively find all expressions from a traversable. + def seqToExpressions(seq: Traversable[Any]): Traversable[Expression] = seq.flatMap { + case e: Expression => e :: Nil + case s: Traversable[_] => seqToExpressions(s) + case other => Nil + } + productIterator.flatMap { case e: Expression => e :: Nil case Some(e: Expression) => e :: Nil - case seq: Traversable[_] => seq.flatMap { - case e: Expression => e :: Nil - case other => Nil - } + case seq: Traversable[_] => seqToExpressions(seq) case other => Nil }.toSeq } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index fb963e2f8f7e7..e2b97b27a6c2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.Utils +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -219,8 +220,6 @@ case class Aggregate( !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions } - lazy val newAggregation: Option[Aggregate] = Utils.tryConvert(this) - override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) } @@ -306,6 +305,9 @@ case class Expand( output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { + override def references: AttributeSet = + AttributeSet(projections.flatten.flatMap(_.references)) + override def statistics: Statistics = { // TODO shouldn't we factor in the size of the projection versus the size of the backing child // row? @@ -384,6 +386,20 @@ case class Rollup( this.copy(aggregations = aggs) } +case class Pivot( + groupByExprs: Seq[NamedExpression], + pivotColumn: Expression, + pivotValues: Seq[Literal], + aggregates: Seq[Expression], + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { + case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) + case _ => pivotValues.flatMap{ value => + aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)()) + } + } +} + case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output @@ -466,10 +482,13 @@ case class MapPartitions[T, U]( } /** Factory for constructing new `AppendColumn` nodes. */ -object AppendColumn { - def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = { +object AppendColumns { + def apply[T, U : Encoder]( + func: T => U, + tEncoder: ExpressionEncoder[T], + child: LogicalPlan): AppendColumns[T, U] = { val attrs = encoderFor[U].schema.toAttributes - new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child) + new AppendColumns[T, U](func, tEncoder, encoderFor[U], attrs, child) } } @@ -478,7 +497,7 @@ object AppendColumn { * resulting columns at the end of the input row. tEncoder/uEncoder are used respectively to * decode/encode from the JVM object representation expected by `func.` */ -case class AppendColumn[T, U]( +case class AppendColumns[T, U]( func: T => U, tEncoder: ExpressionEncoder[T], uEncoder: ExpressionEncoder[U], @@ -490,14 +509,16 @@ case class AppendColumn[T, U]( /** Factory for constructing new `MapGroups` nodes. */ object MapGroups { - def apply[K : Encoder, T : Encoder, U : Encoder]( - func: (K, Iterator[T]) => Iterator[U], + def apply[K, T, U : Encoder]( + func: (K, Iterator[T]) => TraversableOnce[U], + kEncoder: ExpressionEncoder[K], + tEncoder: ExpressionEncoder[T], groupingAttributes: Seq[Attribute], child: LogicalPlan): MapGroups[K, T, U] = { new MapGroups( func, - encoderFor[K], - encoderFor[T], + kEncoder, + tEncoder, encoderFor[U], groupingAttributes, encoderFor[U].schema.toAttributes, @@ -511,7 +532,7 @@ object MapGroups { * object representation of all the rows with that key. */ case class MapGroups[K, T, U]( - func: (K, Iterator[T]) => Iterator[U], + func: (K, Iterator[T]) => TraversableOnce[U], kEncoder: ExpressionEncoder[K], tEncoder: ExpressionEncoder[T], uEncoder: ExpressionEncoder[U], @@ -524,7 +545,7 @@ case class MapGroups[K, T, U]( /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder]( - func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], + func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], left: LogicalPlan, @@ -548,7 +569,7 @@ object CoGroup { * right children. */ case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], + func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], kEncoder: ExpressionEncoder[K], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index f5fff90e5a542..8fb3f41f1bd6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -110,7 +110,7 @@ object DateTimeUtils { } def stringToTime(s: String): java.util.Date = { - var indexOfGMT = s.indexOf("GMT"); + val indexOfGMT = s.indexOf("GMT") if (indexOfGMT != -1) { // ISO8601 with a weird time zone specifier (2000-01-01T00:00GMT+01:00) val s0 = s.substring(0, indexOfGMT) @@ -395,16 +395,19 @@ object DateTimeUtils { /** * Returns the microseconds since year zero (-17999) from microseconds since epoch. */ - def absoluteMicroSecond(microsec: SQLTimestamp): SQLTimestamp = { + private def absoluteMicroSecond(microsec: SQLTimestamp): SQLTimestamp = { microsec + toYearZero * MICROS_PER_DAY } + private def localTimestamp(microsec: SQLTimestamp): SQLTimestamp = { + absoluteMicroSecond(microsec) + defaultTimeZone.getOffset(microsec / 1000) * 1000L + } + /** * Returns the hour value of a given timestamp value. The timestamp is expressed in microseconds. */ def getHours(microsec: SQLTimestamp): Int = { - val localTs = absoluteMicroSecond(microsec) + defaultTimeZone.getOffset(microsec / 1000) * 1000L - ((localTs / MICROS_PER_SECOND / 3600) % 24).toInt + ((localTimestamp(microsec) / MICROS_PER_SECOND / 3600) % 24).toInt } /** @@ -412,8 +415,7 @@ object DateTimeUtils { * microseconds. */ def getMinutes(microsec: SQLTimestamp): Int = { - val localTs = absoluteMicroSecond(microsec) + defaultTimeZone.getOffset(microsec / 1000) * 1000L - ((localTs / MICROS_PER_SECOND / 60) % 60).toInt + ((localTimestamp(microsec) / MICROS_PER_SECOND / 60) % 60).toInt } /** @@ -421,7 +423,7 @@ object DateTimeUtils { * microseconds. */ def getSeconds(microsec: SQLTimestamp): Int = { - ((absoluteMicroSecond(microsec) / MICROS_PER_SECOND) % 60).toInt + ((localTimestamp(microsec) / MICROS_PER_SECOND) % 60).toInt } private[this] def isLeapYear(year: Int): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index e9bf7b33e35be..96588bb5dc1bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -23,7 +23,7 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class GenericArrayData(val array: Array[Any]) extends ArrayData { - def this(seq: scala.collection.GenIterable[Any]) = this(seq.toArray) + def this(seq: Seq[Any]) = this(seq.toArray) // TODO: This is boxing. We should specialize. def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index bcf4d78fb9371..f603cbfb0cc21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -57,6 +57,7 @@ object TypeUtils { def getInterpretedOrdering(t: DataType): Ordering[Any] = { t match { case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] + case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 1d2d007c2b4d2..a5ae8bb0e5eb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -84,6 +84,7 @@ private[sql] object TypeCollection { * Types that can be ordered/compared. In the long run we should probably make this a trait * that can be mixed into each data type, and perhaps create an [[AbstractDataType]]. */ + // TODO: Should we consolidate this with RowOrdering.isOrderable? val Ordered = TypeCollection( BooleanType, ByteType, ShortType, IntegerType, LongType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 5770f59b53077..a001eadcc61d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.types +import org.apache.spark.sql.catalyst.util.ArrayData import org.json4s.JsonDSL._ import org.apache.spark.annotation.DeveloperApi +import scala.math.Ordering + object ArrayType extends AbstractDataType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ @@ -81,4 +84,49 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { f(this) || elementType.existsRecursively(f) } + + @transient + private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] { + private[this] val elementOrdering: Ordering[Any] = elementType match { + case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]] + case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + case other => + throw new IllegalArgumentException(s"Type $other does not support ordered operations") + } + + def compare(x: ArrayData, y: ArrayData): Int = { + val leftArray = x + val rightArray = y + val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements()) + var i = 0 + while (i < minLength) { + val isNullLeft = leftArray.isNullAt(i) + val isNullRight = rightArray.isNullAt(i) + if (isNullLeft && isNullRight) { + // Do nothing. + } else if (isNullLeft) { + return -1 + } else if (isNullRight) { + return 1 + } else { + val comp = + elementOrdering.compare( + leftArray.get(i, elementType), + rightArray.get(i, elementType)) + if (comp != 0) { + return comp + } + } + i += 1 + } + if (leftArray.numElements() < rightArray.numElements()) { + return -1 + } else if (leftArray.numElements() > rightArray.numElements()) { + return 1 + } else { + return 0 + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index fbdd3a7776f50..ee435578743fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -23,8 +23,67 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.util.{MapData, ArrayBasedMapData, GenericArrayData, ArrayData} import org.apache.spark.sql.types._ +import scala.beans.{BeanProperty, BeanInfo} + +@BeanInfo +private[sql] case class GroupableData(@BeanProperty data: Int) + +private[sql] class GroupableUDT extends UserDefinedType[GroupableData] { + + override def sqlType: DataType = IntegerType + + override def serialize(obj: Any): Int = { + obj match { + case groupableData: GroupableData => groupableData.data + } + } + + override def deserialize(datum: Any): GroupableData = { + datum match { + case data: Int => GroupableData(data) + } + } + + override def userClass: Class[GroupableData] = classOf[GroupableData] + + private[spark] override def asNullable: GroupableUDT = this +} + +@BeanInfo +private[sql] case class UngroupableData(@BeanProperty data: Map[Int, Int]) + +private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] { + + override def sqlType: DataType = MapType(IntegerType, IntegerType) + + override def serialize(obj: Any): MapData = { + obj match { + case groupableData: UngroupableData => + val keyArray = new GenericArrayData(groupableData.data.keys.toSeq) + val valueArray = new GenericArrayData(groupableData.data.values.toSeq) + new ArrayBasedMapData(keyArray, valueArray) + } + } + + override def deserialize(datum: Any): UngroupableData = { + datum match { + case data: MapData => + val keyArray = data.keyArray().array + val valueArray = data.valueArray().array + assert(keyArray.length == valueArray.length) + val mapData = keyArray.zip(valueArray).toMap.asInstanceOf[Map[Int, Int]] + UngroupableData(mapData) + } + } + + override def userClass: Class[UngroupableData] = classOf[UngroupableData] + + private[spark] override def asNullable: UngroupableUDT = this +} + case class TestFunction( children: Seq[Expression], inputTypes: Seq[AbstractDataType]) @@ -103,8 +162,8 @@ class AnalysisErrorSuite extends AnalysisTest { errorTest( "sorting by unsupported column types", - listRelation.orderBy('list.asc), - "sort" :: "type" :: "array" :: Nil) + mapRelation.orderBy('map.asc), + "sort" :: "type" :: "map" :: Nil) errorTest( "non-boolean filters", @@ -171,16 +230,18 @@ class AnalysisErrorSuite extends AnalysisTest { test("SPARK-6452 regression test") { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) + // Since we manually construct the logical plan at here and Sum only accetp + // LongType, DoubleType, and DecimalType. We use LongType as the type of a. val plan = Aggregate( Nil, - Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil, + Alias(sum(AttributeReference("a", LongType)(exprId = ExprId(1))), "b")() :: Nil, LocalRelation( - AttributeReference("a", IntegerType)(exprId = ExprId(2)))) + AttributeReference("a", LongType)(exprId = ExprId(2)))) assert(plan.resolved) - assertAnalysisError(plan, "resolved attribute(s) a#1 missing from a#2" :: Nil) + assertAnalysisError(plan, "resolved attribute(s) a#1L missing from a#2L" :: Nil) } test("error test for self-join") { @@ -192,28 +253,66 @@ class AnalysisErrorSuite extends AnalysisTest { assert(error.message.contains("Conflicting attributes")) } - test("aggregation can't work on binary and map types") { - val plan = - Aggregate( - AttributeReference("a", BinaryType)(exprId = ExprId(2)) :: Nil, - Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, - LocalRelation( - AttributeReference("a", BinaryType)(exprId = ExprId(2)), - AttributeReference("b", IntegerType)(exprId = ExprId(1)))) + test("check grouping expression data types") { + def checkDataType(dataType: DataType, shouldSuccess: Boolean): Unit = { + val plan = + Aggregate( + AttributeReference("a", dataType)(exprId = ExprId(2)) :: Nil, + Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, + LocalRelation( + AttributeReference("a", dataType)(exprId = ExprId(2)), + AttributeReference("b", IntegerType)(exprId = ExprId(1)))) + + shouldSuccess match { + case true => + assertAnalysisSuccess(plan, true) + case false => + assertAnalysisError(plan, "expression a cannot be used as a grouping expression" :: Nil) + } + } - assertAnalysisError(plan, - "binary type expression a cannot be used in grouping expression" :: Nil) + val supportedDataTypes = Seq( + StringType, BinaryType, + NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", StringType, nullable = true), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), + new GroupableUDT()) + supportedDataTypes.foreach { dataType => + checkDataType(dataType, shouldSuccess = true) + } - val plan2 = + val unsupportedDataTypes = Seq( + MapType(StringType, LongType), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", MapType(StringType, LongType), nullable = true), + new UngroupableUDT()) + unsupportedDataTypes.foreach { dataType => + checkDataType(dataType, shouldSuccess = false) + } + } + + test("we should fail analysis when we find nested aggregate functions") { + val plan = Aggregate( - AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)) :: Nil, - Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, + AttributeReference("a", IntegerType)(exprId = ExprId(2)) :: Nil, + Alias(sum(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1)))), "c")() :: Nil, LocalRelation( - AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), + AttributeReference("a", IntegerType)(exprId = ExprId(2)), AttributeReference("b", IntegerType)(exprId = ExprId(1)))) - assertAnalysisError(plan2, - "map type expression a cannot be used in grouping expression" :: Nil) + assertAnalysisError( + plan, + "It is not allowed to use an aggregate function in the argument of " + + "another aggregate function." :: Nil) } test("Join can't work on binary and map types") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 71d2939ecffe6..65f09b46afae1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -45,7 +45,7 @@ class AnalysisSuite extends AnalysisTest { val explode = Explode(AttributeReference("a", IntegerType, nullable = true)()) assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved) - assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved) + assert(!Project(Seq(Alias(count(Literal(1)), "count")()), testRelation).resolved) } test("analyze project") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 40c4ae7920918..fed591fd90a9a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index c9bcc68f02030..ba1866efc84e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -22,8 +22,9 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.{TypeCollection, StringType} +import org.apache.spark.sql.types.{LongType, TypeCollection, StringType} class ExpressionTypeCheckingSuite extends SparkFunSuite { @@ -31,7 +32,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { 'intField.int, 'stringField.string, 'booleanField.boolean, - 'complexField.array(StringType)) + 'arrayField.array(StringType), + 'mapField.map(StringType, LongType)) def assertError(expr: Expression, errorMessage: String): Unit = { val e = intercept[AnalysisException] { @@ -89,9 +91,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(BitwiseOr('booleanField, 'booleanField), "requires integral type") assertError(BitwiseXor('booleanField, 'booleanField), "requires integral type") - assertError(MaxOf('complexField, 'complexField), + assertError(MaxOf('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") - assertError(MinOf('complexField, 'complexField), + assertError(MinOf('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") } @@ -108,20 +110,20 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(EqualTo('intField, 'booleanField)) assertSuccess(EqualNullSafe('intField, 'booleanField)) - assertErrorForDifferingTypes(EqualTo('intField, 'complexField)) - assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField)) + assertErrorForDifferingTypes(EqualTo('intField, 'mapField)) + assertErrorForDifferingTypes(EqualNullSafe('intField, 'mapField)) assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError(LessThan('complexField, 'complexField), + assertError(LessThan('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") - assertError(LessThanOrEqual('complexField, 'complexField), + assertError(LessThanOrEqual('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") - assertError(GreaterThan('complexField, 'complexField), + assertError(GreaterThan('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") - assertError(GreaterThanOrEqual('complexField, 'complexField), + assertError(GreaterThanOrEqual('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") assertError(If('intField, 'stringField, 'stringField), @@ -129,10 +131,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) assertError( - CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)), + CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'mapField)), "THEN and ELSE expressions should all be same type or coercible to a common type") assertError( - CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)), + CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'mapField)), "THEN and ELSE expressions should all be same type or coercible to a common type") assertError( CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)), @@ -140,15 +142,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for aggregates") { + // We use AggregateFunction directly at here because the error will be thrown from it + // instead of from AggregateExpression, which is the wrapper of an AggregateFunction. + // We will cast String to Double for sum and average assertSuccess(Sum('stringField)) - assertSuccess(SumDistinct('stringField)) assertSuccess(Average('stringField)) + assertSuccess(Min('arrayField)) - assertError(Min('complexField), "min does not support ordering on type") - assertError(Max('complexField), "max does not support ordering on type") + assertError(Min('mapField), "min does not support ordering on type") + assertError(Max('mapField), "max does not support ordering on type") assertError(Sum('booleanField), "function sum requires numeric type") - assertError(SumDistinct('booleanField), "function sumDistinct requires numeric type") assertError(Average('booleanField), "function average requires numeric type") } @@ -182,7 +186,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Round('intField, 'intField), "Only foldable Expression is allowed") assertError(Round('intField, 'booleanField), "requires int type") - assertError(Round('intField, 'complexField), "requires int type") + assertError(Round('intField, 'mapField), "requires int type") assertError(Round('booleanField, 'intField), "requires numeric type") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala index 05b870705e7ea..bc07b609a3413 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala @@ -48,4 +48,7 @@ object TestRelations { val listRelation = LocalRelation( AttributeReference("list", ArrayType(IntegerType))()) + + val mapRelation = LocalRelation( + AttributeReference("map", MapType(IntegerType, IntegerType))()) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index b0dacf7f555e0..9fe64b4cf10e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -17,232 +17,27 @@ package org.apache.spark.sql.catalyst.encoders -import scala.collection.mutable.ArrayBuffer -import scala.reflect.runtime.universe._ +import java.util.Arrays import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.types.{StructField, ArrayType} - -case class RepeatedStruct(s: Seq[PrimitiveData]) - -case class NestedArray(a: Array[Array[Int]]) - -case class BoxedData( - intField: java.lang.Integer, - longField: java.lang.Long, - doubleField: java.lang.Double, - floatField: java.lang.Float, - shortField: java.lang.Short, - byteField: java.lang.Byte, - booleanField: java.lang.Boolean) - -case class RepeatedData( - arrayField: Seq[Int], - arrayFieldContainsNull: Seq[java.lang.Integer], - mapField: scala.collection.Map[Int, Long], - mapFieldNull: scala.collection.Map[Int, java.lang.Long], - structField: PrimitiveData) - -case class SpecificCollection(l: List[Int]) - -class ExpressionEncoderSuite extends SparkFunSuite { - - encodeDecodeTest(1) - encodeDecodeTest(1L) - encodeDecodeTest(1.toDouble) - encodeDecodeTest(1.toFloat) - encodeDecodeTest(true) - encodeDecodeTest(false) - encodeDecodeTest(1.toShort) - encodeDecodeTest(1.toByte) - encodeDecodeTest("hello") - - encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) - - // TODO: Support creating specific subclasses of Seq. - ignore("Specific collection types") { encodeDecodeTest(SpecificCollection(1 :: Nil)) } - - encodeDecodeTest( - OptionalData( - Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), - Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) - - encodeDecodeTest(OptionalData(None, None, None, None, None, None, None, None)) - - encodeDecodeTest( - BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) - - encodeDecodeTest( - BoxedData(null, null, null, null, null, null, null)) - - encodeDecodeTest( - RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) - - encodeDecodeTest( - RepeatedData( - Seq(1, 2), - Seq(new Integer(1), null, new Integer(2)), - Map(1 -> 2L), - Map(1 -> null), - PrimitiveData(1, 1, 1, 1, 1, 1, true))) - - encodeDecodeTest(("nullable Seq[Integer]", Seq[Integer](1, null))) - - encodeDecodeTest(("Seq[(String, String)]", - Seq(("a", "b")))) - encodeDecodeTest(("Seq[(Int, Int)]", - Seq((1, 2)))) - encodeDecodeTest(("Seq[(Long, Long)]", - Seq((1L, 2L)))) - encodeDecodeTest(("Seq[(Float, Float)]", - Seq((1.toFloat, 2.toFloat)))) - encodeDecodeTest(("Seq[(Double, Double)]", - Seq((1.toDouble, 2.toDouble)))) - encodeDecodeTest(("Seq[(Short, Short)]", - Seq((1.toShort, 2.toShort)))) - encodeDecodeTest(("Seq[(Byte, Byte)]", - Seq((1.toByte, 2.toByte)))) - encodeDecodeTest(("Seq[(Boolean, Boolean)]", - Seq((true, false)))) - - // TODO: Decoding/encoding of complex maps. - ignore("complex maps") { - encodeDecodeTest(("Map[Int, (String, String)]", - Map(1 ->("a", "b")))) - } - - encodeDecodeTest(("ArrayBuffer[(String, String)]", - ArrayBuffer(("a", "b")))) - encodeDecodeTest(("ArrayBuffer[(Int, Int)]", - ArrayBuffer((1, 2)))) - encodeDecodeTest(("ArrayBuffer[(Long, Long)]", - ArrayBuffer((1L, 2L)))) - encodeDecodeTest(("ArrayBuffer[(Float, Float)]", - ArrayBuffer((1.toFloat, 2.toFloat)))) - encodeDecodeTest(("ArrayBuffer[(Double, Double)]", - ArrayBuffer((1.toDouble, 2.toDouble)))) - encodeDecodeTest(("ArrayBuffer[(Short, Short)]", - ArrayBuffer((1.toShort, 2.toShort)))) - encodeDecodeTest(("ArrayBuffer[(Byte, Byte)]", - ArrayBuffer((1.toByte, 2.toByte)))) - encodeDecodeTest(("ArrayBuffer[(Boolean, Boolean)]", - ArrayBuffer((true, false)))) - - encodeDecodeTest(("Seq[Seq[(Int, Int)]]", - Seq(Seq((1, 2))))) - - encodeDecodeTestCustom(("Array[Array[(Int, Int)]]", - Array(Array((1, 2))))) - { (l, r) => l._2(0)(0) == r._2(0)(0) } - - encodeDecodeTestCustom(("Array[Array[(Int, Int)]]", - Array(Array(Array((1, 2)))))) - { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Array[(Int, Int)]]]", - Array(Array(Array(Array((1, 2))))))) - { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Array[Array[(Int, Int)]]]]", - Array(Array(Array(Array(Array((1, 2)))))))) - { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) } - - - encodeDecodeTestCustom(("Array[Array[Integer]]", - Array(Array[Integer](1)))) - { (l, r) => l._2(0)(0) == r._2(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Int]]", - Array(Array(1)))) - { (l, r) => l._2(0)(0) == r._2(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Int]]", - Array(Array(Array(1))))) - { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Array[Int]]]", - Array(Array(Array(Array(1)))))) - { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Array[Array[Int]]]]", - Array(Array(Array(Array(Array(1))))))) - { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) } - - encodeDecodeTest(("Array[Byte] null", - null: Array[Byte])) - encodeDecodeTestCustom(("Array[Byte]", - Array[Byte](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Int] null", - null: Array[Int])) - encodeDecodeTestCustom(("Array[Int]", - Array[Int](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Long] null", - null: Array[Long])) - encodeDecodeTestCustom(("Array[Long]", - Array[Long](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Double] null", - null: Array[Double])) - encodeDecodeTestCustom(("Array[Double]", - Array[Double](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Float] null", - null: Array[Float])) - encodeDecodeTestCustom(("Array[Float]", - Array[Float](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Boolean] null", - null: Array[Boolean])) - encodeDecodeTestCustom(("Array[Boolean]", - Array[Boolean](true, false))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Short] null", - null: Array[Short])) - encodeDecodeTestCustom(("Array[Short]", - Array[Short](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTestCustom(("java.sql.Timestamp", - new java.sql.Timestamp(1))) - { (l, r) => l._2.toString == r._2.toString } - - encodeDecodeTestCustom(("java.sql.Date", new java.sql.Date(1))) - { (l, r) => l._2.toString == r._2.toString } - - /** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */ - protected def encodeDecodeTest[T : TypeTag](inputData: T) = - encodeDecodeTestCustom[T](inputData)((l, r) => l == r) - - /** - * Constructs a test that round-trips `t` through an encoder, checking the results to ensure it - * matches the original. - */ - protected def encodeDecodeTestCustom[T : TypeTag]( - inputData: T)( - c: (T, T) => Boolean) = { - test(s"encode/decode: $inputData - ${inputData.getClass.getName}") { - val encoder = try ExpressionEncoder[T]() catch { - case e: Exception => - fail(s"Exception thrown generating encoder", e) - } - val convertedData = encoder.toRow(inputData) +import org.apache.spark.sql.types.ArrayType + +abstract class ExpressionEncoderSuite extends SparkFunSuite { + protected def encodeDecodeTest[T]( + input: T, + encoder: ExpressionEncoder[T], + testName: String): Unit = { + test(s"encode/decode for $testName: $input") { + val row = encoder.toRow(input) val schema = encoder.schema.toAttributes val boundEncoder = encoder.resolve(schema).bind(schema) - val convertedBack = try boundEncoder.fromRow(convertedData) catch { + val convertedBack = try boundEncoder.fromRow(row) catch { case e: Exception => fail( s"""Exception thrown while decoding - |Converted: $convertedData + |Converted: $row |Schema: ${schema.mkString(",")} |${encoder.schema.treeString} | @@ -252,18 +47,27 @@ class ExpressionEncoderSuite extends SparkFunSuite { """.stripMargin, e) } - if (!c(inputData, convertedBack)) { + val isCorrect = (input, convertedBack) match { + case (b1: Array[Byte], b2: Array[Byte]) => Arrays.equals(b1, b2) + case (b1: Array[Int], b2: Array[Int]) => Arrays.equals(b1, b2) + case (b1: Array[Array[_]], b2: Array[Array[_]]) => + Arrays.deepEquals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]]) + case (b1: Array[_], b2: Array[_]) => + Arrays.equals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]]) + case _ => input == convertedBack + } + + if (!isCorrect) { val types = convertedBack match { case c: Product => c.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",") case other => other.getClass.getName } - val encodedData = try { - convertedData.toSeq(encoder.schema).zip(encoder.schema).map { - case (a: ArrayData, StructField(_, at: ArrayType, _, _)) => - a.toArray[Any](at.elementType).toSeq + row.toSeq(encoder.schema).zip(schema).map { + case (a: ArrayData, AttributeReference(_, ArrayType(et, _), _, _)) => + a.toArray[Any](et).toSeq case (other, _) => other }.mkString("[", ",", "]") @@ -274,7 +78,7 @@ class ExpressionEncoderSuite extends SparkFunSuite { fail( s"""Encoded/Decoded data does not match input data | - |in: $inputData + |in: $input |out: $convertedBack |types: $types | @@ -282,11 +86,10 @@ class ExpressionEncoderSuite extends SparkFunSuite { |Schema: ${schema.mkString(",")} |${encoder.schema.treeString} | - |Extract Expressions: - |$boundEncoder + |fromRow Expressions: + |${boundEncoder.fromRowExpression.treeString} """.stripMargin) - } } - + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala new file mode 100644 index 0000000000000..55821c4370684 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import java.sql.{Date, Timestamp} + +class FlatEncoderSuite extends ExpressionEncoderSuite { + encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean") + encodeDecodeTest(-3.toByte, FlatEncoder[Byte], "primitive byte") + encodeDecodeTest(-3.toShort, FlatEncoder[Short], "primitive short") + encodeDecodeTest(-3, FlatEncoder[Int], "primitive int") + encodeDecodeTest(-3L, FlatEncoder[Long], "primitive long") + encodeDecodeTest(-3.7f, FlatEncoder[Float], "primitive float") + encodeDecodeTest(-3.7, FlatEncoder[Double], "primitive double") + + encodeDecodeTest(new java.lang.Boolean(false), FlatEncoder[java.lang.Boolean], "boxed boolean") + encodeDecodeTest(new java.lang.Byte(-3.toByte), FlatEncoder[java.lang.Byte], "boxed byte") + encodeDecodeTest(new java.lang.Short(-3.toShort), FlatEncoder[java.lang.Short], "boxed short") + encodeDecodeTest(new java.lang.Integer(-3), FlatEncoder[java.lang.Integer], "boxed int") + encodeDecodeTest(new java.lang.Long(-3L), FlatEncoder[java.lang.Long], "boxed long") + encodeDecodeTest(new java.lang.Float(-3.7f), FlatEncoder[java.lang.Float], "boxed float") + encodeDecodeTest(new java.lang.Double(-3.7), FlatEncoder[java.lang.Double], "boxed double") + + encodeDecodeTest(BigDecimal("32131413.211321313"), FlatEncoder[BigDecimal], "scala decimal") + type JDecimal = java.math.BigDecimal + // encodeDecodeTest(new JDecimal("231341.23123"), FlatEncoder[JDecimal], "java decimal") + + encodeDecodeTest("hello", FlatEncoder[String], "string") + encodeDecodeTest(Date.valueOf("2012-12-23"), FlatEncoder[Date], "date") + encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), FlatEncoder[Timestamp], "timestamp") + encodeDecodeTest(Array[Byte](13, 21, -23), FlatEncoder[Array[Byte]], "binary") + + encodeDecodeTest(Seq(31, -123, 4), FlatEncoder[Seq[Int]], "seq of int") + encodeDecodeTest(Seq("abc", "xyz"), FlatEncoder[Seq[String]], "seq of string") + encodeDecodeTest(Seq("abc", null, "xyz"), FlatEncoder[Seq[String]], "seq of string with null") + encodeDecodeTest(Seq.empty[Int], FlatEncoder[Seq[Int]], "empty seq of int") + encodeDecodeTest(Seq.empty[String], FlatEncoder[Seq[String]], "empty seq of string") + + encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)), + FlatEncoder[Seq[Seq[Int]]], "seq of seq of int") + encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")), + FlatEncoder[Seq[Seq[String]]], "seq of seq of string") + + encodeDecodeTest(Array(31, -123, 4), FlatEncoder[Array[Int]], "array of int") + encodeDecodeTest(Array("abc", "xyz"), FlatEncoder[Array[String]], "array of string") + encodeDecodeTest(Array("a", null, "x"), FlatEncoder[Array[String]], "array of string with null") + encodeDecodeTest(Array.empty[Int], FlatEncoder[Array[Int]], "empty array of int") + encodeDecodeTest(Array.empty[String], FlatEncoder[Array[String]], "empty array of string") + + encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)), + FlatEncoder[Array[Array[Int]]], "array of array of int") + encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")), + FlatEncoder[Array[Array[String]]], "array of array of string") + + encodeDecodeTest(Map(1 -> "a", 2 -> "b"), FlatEncoder[Map[Int, String]], "map") + encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null") + encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), + FlatEncoder[Map[Int, Map[String, Int]]], "map of map") +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala new file mode 100644 index 0000000000000..bc539d62c537d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} + +case class RepeatedStruct(s: Seq[PrimitiveData]) + +case class NestedArray(a: Array[Array[Int]]) { + override def equals(other: Any): Boolean = other match { + case NestedArray(otherArray) => + java.util.Arrays.deepEquals( + a.asInstanceOf[Array[AnyRef]], + otherArray.asInstanceOf[Array[AnyRef]]) + case _ => false + } +} + +case class BoxedData( + intField: java.lang.Integer, + longField: java.lang.Long, + doubleField: java.lang.Double, + floatField: java.lang.Float, + shortField: java.lang.Short, + byteField: java.lang.Byte, + booleanField: java.lang.Boolean) + +case class RepeatedData( + arrayField: Seq[Int], + arrayFieldContainsNull: Seq[java.lang.Integer], + mapField: scala.collection.Map[Int, Long], + mapFieldNull: scala.collection.Map[Int, java.lang.Long], + structField: PrimitiveData) + +case class SpecificCollection(l: List[Int]) + +class ProductEncoderSuite extends ExpressionEncoderSuite { + + productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) + + productTest( + OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), + Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) + + productTest(OptionalData(None, None, None, None, None, None, None, None)) + + productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) + + productTest(BoxedData(null, null, null, null, null, null, null)) + + productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) + + productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest( + RepeatedData( + Seq(1, 2), + Seq(new Integer(1), null, new Integer(2)), + Map(1 -> 2L), + Map(1 -> null), + PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6)))) + + productTest(("Seq[(String, String)]", + Seq(("a", "b")))) + productTest(("Seq[(Int, Int)]", + Seq((1, 2)))) + productTest(("Seq[(Long, Long)]", + Seq((1L, 2L)))) + productTest(("Seq[(Float, Float)]", + Seq((1.toFloat, 2.toFloat)))) + productTest(("Seq[(Double, Double)]", + Seq((1.toDouble, 2.toDouble)))) + productTest(("Seq[(Short, Short)]", + Seq((1.toShort, 2.toShort)))) + productTest(("Seq[(Byte, Byte)]", + Seq((1.toByte, 2.toByte)))) + productTest(("Seq[(Boolean, Boolean)]", + Seq((true, false)))) + + productTest(("ArrayBuffer[(String, String)]", + ArrayBuffer(("a", "b")))) + productTest(("ArrayBuffer[(Int, Int)]", + ArrayBuffer((1, 2)))) + productTest(("ArrayBuffer[(Long, Long)]", + ArrayBuffer((1L, 2L)))) + productTest(("ArrayBuffer[(Float, Float)]", + ArrayBuffer((1.toFloat, 2.toFloat)))) + productTest(("ArrayBuffer[(Double, Double)]", + ArrayBuffer((1.toDouble, 2.toDouble)))) + productTest(("ArrayBuffer[(Short, Short)]", + ArrayBuffer((1.toShort, 2.toShort)))) + productTest(("ArrayBuffer[(Byte, Byte)]", + ArrayBuffer((1.toByte, 2.toByte)))) + productTest(("ArrayBuffer[(Boolean, Boolean)]", + ArrayBuffer((true, false)))) + + productTest(("Seq[Seq[(Int, Int)]]", + Seq(Seq((1, 2))))) + + encodeDecodeTest( + 1 -> 10L, + ExpressionEncoder.tuple(FlatEncoder[Int], FlatEncoder[Long]), + "tuple with 2 flat encoders") + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)), + ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], ProductEncoder[(Int, Long)]), + "tuple with 2 product encoders") + + encodeDecodeTest( + (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3), + ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], FlatEncoder[Int]), + "tuple with flat encoder and product encoder") + + encodeDecodeTest( + (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)), + ExpressionEncoder.tuple(FlatEncoder[Int], ProductEncoder[PrimitiveData]), + "tuple with product encoder and flat encoder") + + encodeDecodeTest( + (1, (10, 100L)), + { + val intEnc = FlatEncoder[Int] + val longEnc = FlatEncoder[Long] + ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) + }, + "nested tuple encoder") + + private def productTest[T <: Product : TypeTag](input: T): Unit = { + encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index e8301e8e06b52..c868ddec1bab2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -19,14 +19,62 @@ package org.apache.spark.sql.catalyst.encoders import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +@SQLUserDefinedType(udt = classOf[ExamplePointUDT]) +class ExamplePoint(val x: Double, val y: Double) extends Serializable { + override def hashCode: Int = 41 * (41 + x.toInt) + y.toInt + override def equals(that: Any): Boolean = { + if (that.isInstanceOf[ExamplePoint]) { + val e = that.asInstanceOf[ExamplePoint] + (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) && + (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity)) + } else { + false + } + } +} + +/** + * User-defined type for [[ExamplePoint]]. + */ +class ExamplePointUDT extends UserDefinedType[ExamplePoint] { + + override def sqlType: DataType = ArrayType(DoubleType, false) + + override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT" + + override def serialize(obj: Any): GenericArrayData = { + obj match { + case p: ExamplePoint => + val output = new Array[Any](2) + output(0) = p.x + output(1) = p.y + new GenericArrayData(output) + } + } + + override def deserialize(datum: Any): ExamplePoint = { + datum match { + case values: ArrayData => + new ExamplePoint(values.getDouble(0), values.getDouble(1)) + } + } + + override def userClass: Class[ExamplePoint] = classOf[ExamplePoint] + + private[spark] override def asNullable: ExamplePointUDT = this +} + class RowEncoderSuite extends SparkFunSuite { private val structOfString = new StructType().add("str", StringType) + private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) private val arrayOfString = ArrayType(StringType) private val mapOfString = MapType(StringType, StringType) + private val arrayOfUDT = ArrayType(new ExamplePointUDT, false) encodeDecodeTest( new StructType() @@ -41,7 +89,8 @@ class RowEncoderSuite extends SparkFunSuite { .add("string", StringType) .add("binary", BinaryType) .add("date", DateType) - .add("timestamp", TimestampType)) + .add("timestamp", TimestampType) + .add("udt", new ExamplePointUDT, false)) encodeDecodeTest( new StructType() @@ -68,7 +117,36 @@ class RowEncoderSuite extends SparkFunSuite { .add("structOfArray", new StructType().add("array", arrayOfString)) .add("structOfMap", new StructType().add("map", mapOfString)) .add("structOfArrayAndMap", - new StructType().add("array", arrayOfString).add("map", mapOfString))) + new StructType().add("array", arrayOfString).add("map", mapOfString)) + .add("structOfUDT", structOfUDT)) + + test(s"encode/decode: arrayOfUDT") { + val schema = new StructType() + .add("arrayOfUDT", arrayOfUDT) + + val encoder = RowEncoder(schema) + + val input: Row = Row(Seq(new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4))) + val row = encoder.toRow(input) + val convertedBack = encoder.fromRow(row) + assert(input.getSeq[ExamplePoint](0) == convertedBack.getSeq[ExamplePoint](0)) + } + + test(s"encode/decode: Product") { + val schema = new StructType() + .add("structAsProduct", + new StructType() + .add("int", IntegerType) + .add("string", StringType) + .add("double", DoubleType)) + + val encoder = RowEncoder(schema) + + val input: Row = Row((100, "test", 0.123)) + val row = encoder.toRow(input) + val convertedBack = encoder.fromRow(row) + assert(input.getStruct(0) == convertedBack.getStruct(0)) + } private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index e323467af5f4a..002ed16dcfe7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import scala.math._ - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{Row, RandomDataGenerator} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -49,40 +47,6 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { futures.foreach(Await.result(_, 10.seconds)) } - // Test GenerateOrdering for all common types. For each type, we construct random input rows that - // contain two columns of that type, then for pairs of randomly-generated rows we check that - // GenerateOrdering agrees with RowOrdering. - (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType => - test(s"GenerateOrdering with $dataType") { - val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType)) - val genOrdering = GenerateOrdering.generate( - BoundReference(0, dataType, nullable = true).asc :: - BoundReference(1, dataType, nullable = true).asc :: Nil) - val rowType = StructType( - StructField("a", dataType, nullable = true) :: - StructField("b", dataType, nullable = true) :: Nil) - val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false) - assume(maybeDataGenerator.isDefined) - val randGenerator = maybeDataGenerator.get - val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) - for (_ <- 1 to 50) { - val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow] - val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow] - withClue(s"a = $a, b = $b") { - assert(genOrdering.compare(a, a) === 0) - assert(genOrdering.compare(b, b) === 0) - assert(rowOrdering.compare(a, a) === 0) - assert(rowOrdering.compare(b, b) === 0) - assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a))) - assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a))) - assert( - signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)), - "Generated and non-generated orderings should agree") - } - } - } - } - test("SPARK-8443: split wide projections into blocks due to JVM code size limit") { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 0df673bb9fa02..c1e3c17b87102 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -231,4 +231,18 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } } + + test("function dropAnyNull") { + val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1)))) + val a = create_row("a", "q") + val nullStr: String = null + checkEvaluation(drop, a, a) + checkEvaluation(drop, null, create_row("b", nullStr)) + checkEvaluation(drop, null, create_row(nullStr, nullStr)) + + val row = 'r.struct( + StructField("a", StringType, false), + StructField("b", StringType, true)).at(0) + checkEvaluation(DropAnyNull(row), null, create_row(null)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 610d39e8493cd..53c66d8a754ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -465,6 +465,42 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null) } + test("to_unix_timestamp") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + val fmt3 = "yy-MM-dd" + val sdf3 = new SimpleDateFormat(fmt3) + val date1 = Date.valueOf("2015-07-24") + checkEvaluation( + ToUnixTimestamp(Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss")), 0L) + checkEvaluation(ToUnixTimestamp( + Literal(sdf1.format(new Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) + checkEvaluation( + ToUnixTimestamp(Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) + checkEvaluation( + ToUnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss")), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1)) / 1000L) + checkEvaluation( + ToUnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2)), -1000L) + checkEvaluation(ToUnixTimestamp( + Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3)), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24"))) / 1000L) + val t1 = ToUnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + val t2 = ToUnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + assert(t2 - t1 <= 1) + checkEvaluation( + ToUnixTimestamp(Literal.create(null, DateType), Literal.create(null, StringType)), null) + checkEvaluation( + ToUnixTimestamp(Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss")), null) + checkEvaluation(ToUnixTimestamp( + Literal(date1), Literal.create(null, StringType)), date1.getTime / 1000L) + checkEvaluation( + ToUnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null) + } + test("datediff") { checkEvaluation( DateDiff(Literal(Date.valueOf("2015-07-24")), Literal(Date.valueOf("2015-07-21"))), 3) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index f33125f463e14..7b754091f4714 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -209,8 +209,12 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal("f5") :: Nil + private def checkJsonTuple(jt: JsonTuple, expected: InternalRow): Unit = { + assert(jt.eval(null).toSeq.head === expected) + } + test("json_tuple - hive key 1") { - checkEvaluation( + checkJsonTuple( JsonTuple( Literal("""{"f1": "value1", "f2": "value2", "f3": 3, "f5": 5.23}""") :: jsonTupleQuery), @@ -218,7 +222,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 2") { - checkEvaluation( + checkJsonTuple( JsonTuple( Literal("""{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") :: jsonTupleQuery), @@ -226,7 +230,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 2 (mix of foldable fields)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("""{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") :: Literal("f1") :: NonFoldableLiteral("f2") :: @@ -238,7 +242,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 3") { - checkEvaluation( + checkJsonTuple( JsonTuple( Literal("""{"f1": "value13", "f4": "value44", "f3": "value33", "f2": 2, "f5": 5.01}""") :: jsonTupleQuery), @@ -247,7 +251,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 3 (nonfoldable json)") { - checkEvaluation( + checkJsonTuple( JsonTuple( NonFoldableLiteral( """{"f1": "value13", "f4": "value44", @@ -258,7 +262,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 3 (nonfoldable fields)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal( """{"f1": "value13", "f4": "value44", | "f3": "value33", "f2": 2, "f5": 5.01}""".stripMargin) :: @@ -273,43 +277,43 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 4 - null json") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal(null) :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - hive key 5 - null and empty fields") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("""{"f1": "", "f5": null}""") :: jsonTupleQuery), InternalRow.fromSeq(Seq(UTF8String.fromString(""), null, null, null, null))) } test("json_tuple - hive key 6 - invalid json (array)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("[invalid JSON string]") :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - invalid json (object start only)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("{") :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - invalid json (no object end)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("""{"foo": "bar"""") :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - invalid json (invalid json)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("\\") :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - preserve newlines") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("{\"a\":\"b\nc\"}") :: Literal("a") :: Nil), InternalRow.fromSeq(Seq(UTF8String.fromString("b\nc")))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala new file mode 100644 index 0000000000000..7ad8657bde128 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.math._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{Row, RandomDataGenerator} +import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.types._ + +class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper { + + def compareArrays(a: Seq[Any], b: Seq[Any], expected: Int): Unit = { + test(s"compare two arrays: a = $a, b = $b") { + val dataType = ArrayType(IntegerType) + val rowType = StructType(StructField("array", dataType, nullable = true) :: Nil) + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) + val rowA = toCatalyst(Row(a)).asInstanceOf[InternalRow] + val rowB = toCatalyst(Row(b)).asInstanceOf[InternalRow] + Seq(Ascending, Descending).foreach { direction => + val sortOrder = direction match { + case Ascending => BoundReference(0, dataType, nullable = true).asc + case Descending => BoundReference(0, dataType, nullable = true).desc + } + val expectedCompareResult = direction match { + case Ascending => signum(expected) + case Descending => -1 * signum(expected) + } + val intOrdering = new InterpretedOrdering(sortOrder :: Nil) + val genOrdering = GenerateOrdering.generate(sortOrder :: Nil) + Seq(intOrdering, genOrdering).foreach { ordering => + assert(ordering.compare(rowA, rowA) === 0) + assert(ordering.compare(rowB, rowB) === 0) + assert(signum(ordering.compare(rowA, rowB)) === expectedCompareResult) + assert(signum(ordering.compare(rowB, rowA)) === -1 * expectedCompareResult) + } + } + } + } + + // Two arrays have the same size. + compareArrays(Seq[Any](), Seq[Any](), 0) + compareArrays(Seq[Any](1), Seq[Any](1), 0) + compareArrays(Seq[Any](1, 2), Seq[Any](1, 2), 0) + compareArrays(Seq[Any](1, 2, 2), Seq[Any](1, 2, 3), -1) + + // Two arrays have different sizes. + compareArrays(Seq[Any](), Seq[Any](1), -1) + compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, 4), -1) + compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, 2), -1) + compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 2, 2), 1) + + // Arrays having nulls. + compareArrays(Seq[Any](1, 2, 3), Seq[Any](1, 2, 3, null), -1) + compareArrays(Seq[Any](), Seq[Any](null), -1) + compareArrays(Seq[Any](null), Seq[Any](null), 0) + compareArrays(Seq[Any](null, null), Seq[Any](null, null), 0) + compareArrays(Seq[Any](null), Seq[Any](null, null), -1) + compareArrays(Seq[Any](null), Seq[Any](1), -1) + compareArrays(Seq[Any](null), Seq[Any](null, 1), -1) + compareArrays(Seq[Any](null, 1), Seq[Any](1, 1), -1) + compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 1), 0) + compareArrays(Seq[Any](1, null, 1), Seq[Any](1, null, 2), -1) + + // Test GenerateOrdering for all common types. For each type, we construct random input rows that + // contain two columns of that type, then for pairs of randomly-generated rows we check that + // GenerateOrdering agrees with RowOrdering. + { + val structType = + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true) + val arrayOfStructType = ArrayType(structType) + val complexTypes = ArrayType(IntegerType) :: structType :: arrayOfStructType :: Nil + (DataTypeTestUtils.atomicTypes ++ complexTypes ++ Set(NullType)).foreach { dataType => + test(s"GenerateOrdering with $dataType") { + val rowOrdering = InterpretedOrdering.forSchema(Seq(dataType, dataType)) + val genOrdering = GenerateOrdering.generate( + BoundReference(0, dataType, nullable = true).asc :: + BoundReference(1, dataType, nullable = true).asc :: Nil) + val rowType = StructType( + StructField("a", dataType, nullable = true) :: + StructField("b", dataType, nullable = true) :: Nil) + val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false) + assume(maybeDataGenerator.isDefined) + val randGenerator = maybeDataGenerator.get + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) + for (_ <- 1 to 50) { + val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow] + val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow] + withClue(s"a = $a, b = $b") { + assert(genOrdering.compare(a, a) === 0) + assert(genOrdering.compare(b, b) === 0) + assert(rowOrdering.compare(a, a) === 0) + assert(rowOrdering.compare(b, b) === 0) + assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a))) + assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a))) + assert( + signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)), + "Generated and non-generated orderings should agree") + } + } + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala new file mode 100644 index 0000000000000..9de066e99d637 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.IntegerType + +class SubexpressionEliminationSuite extends SparkFunSuite { + test("Semantic equals and hash") { + val id = ExprId(1) + val a: AttributeReference = AttributeReference("name", IntegerType)() + val b1 = a.withName("name2").withExprId(id) + val b2 = a.withExprId(id) + + assert(b1 != b2) + assert(a != b1) + assert(b1.semanticEquals(b2)) + assert(!b1.semanticEquals(a)) + assert(a.hashCode != b1.hashCode) + assert(b1.hashCode == b2.hashCode) + assert(b1.semanticHash() == b2.semanticHash()) + } + + test("Expression Equivalence - basic") { + val equivalence = new EquivalentExpressions + assert(equivalence.getAllEquivalentExprs.isEmpty) + + val oneA = Literal(1) + val oneB = Literal(1) + val twoA = Literal(2) + var twoB = Literal(2) + + assert(equivalence.getEquivalentExprs(oneA).isEmpty) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + + // Add oneA and test if it is returned. Since it is a group of one, it does not. + assert(!equivalence.addExpr(oneA)) + assert(equivalence.getEquivalentExprs(oneA).size == 1) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + assert(equivalence.addExpr((oneA))) + assert(equivalence.getEquivalentExprs(oneA).size == 2) + + // Add B and make sure they can see each other. + assert(equivalence.addExpr(oneB)) + // Use exists and reference equality because of how equals is defined. + assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneB)) + assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneA)) + assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneA)) + assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneB)) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + assert(equivalence.getAllEquivalentExprs.size == 1) + assert(equivalence.getAllEquivalentExprs.head.size == 3) + assert(equivalence.getAllEquivalentExprs.head.contains(oneA)) + assert(equivalence.getAllEquivalentExprs.head.contains(oneB)) + + val add1 = Add(oneA, oneB) + val add2 = Add(oneA, oneB) + + equivalence.addExpr(add1) + equivalence.addExpr(add2) + + assert(equivalence.getAllEquivalentExprs.size == 2) + assert(equivalence.getEquivalentExprs(add2).exists(_ eq add1)) + assert(equivalence.getEquivalentExprs(add2).size == 2) + assert(equivalence.getEquivalentExprs(add1).exists(_ eq add2)) + } + + test("Expression Equivalence - Trees") { + val one = Literal(1) + val two = Literal(2) + + val add = Add(one, two) + val abs = Abs(add) + val add2 = Add(add, add) + + var equivalence = new EquivalentExpressions + equivalence.addExprTree(add, true) + equivalence.addExprTree(abs, true) + equivalence.addExprTree(add2, true) + + // Should only have one equivalence for `one + two` + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 1) + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).head.size == 4) + + // Set up the expressions + // one * two, + // (one * two) * (one * two) + // sqrt( (one * two) * (one * two) ) + // (one * two) + sqrt( (one * two) * (one * two) ) + equivalence = new EquivalentExpressions + val mul = Multiply(one, two) + val mul2 = Multiply(mul, mul) + val sqrt = Sqrt(mul2) + val sum = Add(mul2, sqrt) + equivalence.addExprTree(mul, true) + equivalence.addExprTree(mul2, true) + equivalence.addExprTree(sqrt, true) + equivalence.addExprTree(sum, true) + + // (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 3) + assert(equivalence.getEquivalentExprs(mul).size == 3) + assert(equivalence.getEquivalentExprs(mul2).size == 3) + assert(equivalence.getEquivalentExprs(sqrt).size == 2) + assert(equivalence.getEquivalentExprs(sum).size == 1) + + // Some expressions inspired by TPCH-Q1 + // sum(l_quantity) as sum_qty, + // sum(l_extendedprice) as sum_base_price, + // sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + // sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + // avg(l_extendedprice) as avg_price, + // avg(l_discount) as avg_disc + equivalence = new EquivalentExpressions + val quantity = Literal(1) + val price = Literal(1.1) + val discount = Literal(.24) + val tax = Literal(0.1) + equivalence.addExprTree(quantity, false) + equivalence.addExprTree(price, false) + equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), false) + equivalence.addExprTree( + Multiply( + Multiply(price, Subtract(Literal(1), discount)), + Add(Literal(1), tax)), false) + equivalence.addExprTree(price, false) + equivalence.addExprTree(discount, false) + // quantity, price, discount and (price * (1 - discount)) + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 4) + } + + test("Expression equivalence - non deterministic") { + val sum = Add(Rand(0), Rand(0)) + val equivalence = new EquivalentExpressions + equivalence.addExpr(sum) + equivalence.addExpr(sum) + assert(equivalence.getAllEquivalentExprs.isEmpty) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index e67606288f514..8aaefa84937c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -162,7 +162,7 @@ class ConstantFoldingSuite extends PlanTest { testRelation .select( Rand(5L) + Literal(1) as Symbol("c1"), - Sum('a) as Symbol("c2")) + sum('a) as Symbol("c2")) val optimized = Optimize.execute(originalQuery.analyze) @@ -170,7 +170,7 @@ class ConstantFoldingSuite extends PlanTest { testRelation .select( Rand(5L) + Literal(1.0) as Symbol("c1"), - Sum('a) as Symbol("c2")) + sum('a) as Symbol("c2")) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index ed810a12808f0..0290fafe879f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -68,7 +68,7 @@ class FilterPushdownSuite extends PlanTest { test("column pruning for group") { val originalQuery = testRelation - .groupBy('a)('a, Count('b)) + .groupBy('a)('a, count('b)) .select('a) val optimized = Optimize.execute(originalQuery.analyze) @@ -84,7 +84,7 @@ class FilterPushdownSuite extends PlanTest { test("column pruning for group with alias") { val originalQuery = testRelation - .groupBy('a)('a as 'c, Count('b)) + .groupBy('a)('a as 'c, count('b)) .select('c) val optimized = Optimize.execute(originalQuery.analyze) @@ -656,7 +656,7 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: push down filter when filter on group by expression") { val originalQuery = testRelation - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .select('a, 'c) .where('a === 2) @@ -664,7 +664,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .where('a === 2) - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .analyze comparePlans(optimized, correctAnswer) } @@ -672,7 +672,7 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: don't push down filter when filter not on group by expression") { val originalQuery = testRelation .select('a, 'b) - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .where('c === 2L) val optimized = Optimize.execute(originalQuery.analyze) @@ -683,7 +683,7 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: push down filters partially which are subset of group by expressions") { val originalQuery = testRelation .select('a, 'b) - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .where('c === 2L && 'a === 3) val optimized = Optimize.execute(originalQuery.analyze) @@ -691,7 +691,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .select('a, 'b) .where('a === 3) - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .where('c === 2L) .analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 64d15e6b910c1..60d45422bc9b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -358,7 +358,7 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(getSeconds(c.getTimeInMillis * 1000) === 9) } - test("hours / miniute / seconds") { + test("hours / minutes / seconds") { Seq(Timestamp.valueOf("2015-06-11 10:12:35.789"), Timestamp.valueOf("2015-06-11 20:13:40.789"), Timestamp.valueOf("1900-06-11 12:14:50.789"), diff --git a/sql/core/pom.xml b/sql/core/pom.xml index c96855e261ee8..9fd6b5a07ec86 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -110,6 +110,11 @@ mockito-core test + + org.apache.xbean + xbean-asm5-shaded + test + target/scala-${scala.binary.version}/classes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index c32c93897ce0b..82e9cd7f50a31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression + import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DataTypeParser import org.apache.spark.sql.types._ @@ -39,10 +41,32 @@ private[sql] object Column { } /** - * A [[Column]] where an [[Encoder]] has been given for the expected return type. + * A [[Column]] where an [[Encoder]] has been given for the expected input and return type. * @since 1.6.0 + * @tparam T The input type expected for this expression. Can be `Any` if the expression is type + * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). + * @tparam U The output type of this column. */ -class TypedColumn[T](expr: Expression)(implicit val encoder: Encoder[T]) extends Column(expr) +class TypedColumn[-T, U]( + expr: Expression, + private[sql] val encoder: ExpressionEncoder[U]) extends Column(expr) { + + /** + * Inserts the specific input type and schema into any expressions that are expected to operate + * on a decoded object. + */ + private[sql] def withInputType( + inputEncoder: ExpressionEncoder[_], + schema: Seq[Attribute]): TypedColumn[T, U] = { + val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]] + new TypedColumn[T, U] (expr transform { + case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => + ta.copy( + aEncoder = Some(boundEncoder), + children = schema) + }, encoder) + } +} /** * :: Experimental :: @@ -70,6 +94,25 @@ class Column(protected[sql] val expr: Expression) extends Logging { /** Creates a column based on the given expression. */ private def withExpr(newExpr: Expression): Column = new Column(newExpr) + /** + * Returns the expression for this column either with an existing or auto assigned name. + */ + private[sql] def named: NamedExpression = expr match { + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + case u: UnresolvedAttribute => UnresolvedAlias(u) + + case expr: NamedExpression => expr + + // Leave an unaliased generator with an empty list of names since the analyzer will generate + // the correct defaults after the nested expression's type has been resolved. + case explode: Explode => MultiAlias(explode, Nil) + case jt: JsonTuple => MultiAlias(jt, Nil) + + case expr: Expression => Alias(expr, expr.prettyString)() + } + override def toString: String = expr.prettyString override def equals(that: Any): Boolean = that match { @@ -85,7 +128,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * results into the correct JVM types. * @since 1.6.0 */ - def as[T : Encoder]: TypedColumn[T] = new TypedColumn[T](expr) + def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](expr, encoderFor[U]) /** * Extracts a value or values from a complex type. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f2d4db5550273..3ba4ba18d2122 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -23,7 +23,6 @@ import java.util.Properties import scala.language.implicitConversions import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import scala.util.control.NonFatal import com.fasterxml.jackson.core.JsonFactory import org.apache.commons.lang3.StringUtils @@ -34,11 +33,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -115,7 +114,8 @@ private[sql] object DataFrame { @Experimental class DataFrame private[sql]( @transient val sqlContext: SQLContext, - @DeveloperApi @transient val queryExecution: QueryExecution) extends Serializable { + @DeveloperApi @transient val queryExecution: QueryExecution) + extends Queryable with Serializable { // Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure // you wrap it with `withNewExecutionId` if this actions doesn't call other action. @@ -233,15 +233,6 @@ class DataFrame private[sql]( sb.toString() } - override def toString: String = { - try { - schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") - } catch { - case NonFatal(e) => - s"Invalid tree; ${e.getMessage}:\n$queryExecution" - } - } - /** * Returns the object itself. * @group basic @@ -744,18 +735,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def select(cols: Column*): DataFrame = withPlan { - val namedExpressions = cols.map { - // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we - // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to - // make it a NamedExpression. - case Column(u: UnresolvedAttribute) => UnresolvedAlias(u) - case Column(expr: NamedExpression) => expr - // Leave an unaliased explode with an empty list of names since the analyzer will generate the - // correct defaults after the nested expression's type has been resolved. - case Column(explode: Explode) => MultiAlias(explode, Nil) - case Column(expr: Expression) => Alias(expr, expr.prettyString)() - } - Project(namedExpressions.toSeq, logicalPlan) + Project(cols.map(_.named), logicalPlan) } /** @@ -1338,7 +1318,7 @@ class DataFrame private[sql]( if (groupColExprIds.contains(attr.exprId)) { attr } else { - Alias(First(attr), attr.name)() + Alias(new First(attr).toAggregateExpression(), attr.name)() } } Aggregate(groupCols, aggCols, logicalPlan) @@ -1381,11 +1361,11 @@ class DataFrame private[sql]( // The list of summary statistics to compute, in the form of expressions. val statistics = List[(String, Expression => Expression)]( - "count" -> Count, - "mean" -> Average, - "stddev" -> StddevSamp, - "min" -> Min, - "max" -> Max) + "count" -> ((child: Expression) => Count(child).toAggregateExpression()), + "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), + "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), + "min" -> ((child: Expression) => Min(child).toAggregateExpression()), + "max" -> ((child: Expression) => Max(child).toAggregateExpression())) val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList @@ -1478,18 +1458,54 @@ class DataFrame private[sql]( /** * Returns the first `n` rows in the [[DataFrame]]. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. + * * @group action * @since 1.3.0 */ def take(n: Int): Array[Row] = head(n) + /** + * Returns the first `n` rows in the [[DataFrame]] as a list. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. + * + * @group action + * @since 1.6.0 + */ + def takeAsList(n: Int): java.util.List[Row] = java.util.Arrays.asList(take(n) : _*) + /** * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. + * + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * + * For Java API, use [[collectAsList]]. + * * @group action * @since 1.3.0 */ def collect(): Array[Row] = collect(needCallback = true) + /** + * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. + * + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * + * @group action + * @since 1.3.0 + */ + def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ => + withNewExecutionId { + java.util.Arrays.asList(rdd.collect() : _*) + } + } + private def collect(needCallback: Boolean): Array[Row] = { def execute(): Array[Row] = withNewExecutionId { queryExecution.executedPlan.executeCollectPublic() @@ -1502,17 +1518,6 @@ class DataFrame private[sql]( } } - /** - * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. - * @group action - * @since 1.3.0 - */ - def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ => - withNewExecutionId { - java.util.Arrays.asList(rdd.collect() : _*) - } - } - /** * Returns the number of rows in the [[DataFrame]]. * @group action diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 6a194a443ab17..5872fbded3833 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.execution.datasources.json.JSONRelation +import org.apache.spark.sql.execution.datasources.json.{JSONOptions, JSONRelation} import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types.StructType @@ -227,6 +227,15 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. * + * You can set the following JSON-specific options to deal with non-standard JSON files: + *

  • `primitivesAsString` (default `false`): infers all primitive values as a string type
  • + *
  • `allowComments` (default `false`): ignores Java/C++ style comment in JSON records
  • + *
  • `allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names
  • + *
  • `allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes + *
  • + *
  • `allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers + * (e.g. 00012)
  • + * * @param path input path * @since 1.4.0 */ @@ -255,16 +264,13 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def json(jsonRDD: RDD[String]): DataFrame = { - val samplingRatio = extraOptions.getOrElse("samplingRatio", "1.0").toDouble - val primitivesAsString = extraOptions.getOrElse("primitivesAsString", "false").toBoolean sqlContext.baseRelationToDataFrame( new JSONRelation( Some(jsonRDD), - samplingRatio, - primitivesAsString, - userSpecifiedSchema, - None, - None)(sqlContext) + maybeDataSchema = userSpecifiedSchema, + maybePartitionSpec = None, + userDefinedPartitionColumns = None, + parameters = extraOptions.toMap)(sqlContext) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 7887e559a3025..e63a4d5e8b10b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -23,8 +23,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Project, InsertIntoTable} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.sources.HadoopFsRelation @@ -167,17 +167,38 @@ final class DataFrameWriter private[sql](df: DataFrame) { } private def insertInto(tableIdent: TableIdentifier): Unit = { - val partitions = partitioningColumns.map(_.map(col => col -> (None: Option[String])).toMap) + val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap) val overwrite = mode == SaveMode.Overwrite + + // A partitioned relation's schema can be different from the input logicalPlan, since + // partition columns are all moved after data columns. We Project to adjust the ordering. + // TODO: this belongs to the analyzer. + val input = normalizedParCols.map { parCols => + val (inputPartCols, inputDataCols) = df.logicalPlan.output.partition { attr => + parCols.contains(attr.name) + } + Project(inputDataCols ++ inputPartCols, df.logicalPlan) + }.getOrElse(df.logicalPlan) + df.sqlContext.executePlan( InsertIntoTable( UnresolvedRelation(tableIdent), partitions.getOrElse(Map.empty[String, Option[String]]), - df.logicalPlan, + input, overwrite, ifNotExists = false)).toRdd } + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { parCols => + parCols.map { col => + df.logicalPlan.output + .map(_.name) + .find(df.sqlContext.analyzer.resolver(_, col)) + .getOrElse(throw new AnalysisException(s"Partition column $col not found in existing " + + s"columns (${df.logicalPlan.output.map(_.name).mkString(", ")})")) + } + } + /** * Saves the content of the [[DataFrame]] as the specified table. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index fecbdac9a6004..4cc3aa2465f2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -21,14 +21,14 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _} +import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.types.StructType /** @@ -62,21 +62,30 @@ import org.apache.spark.sql.types.StructType class Dataset[T] private[sql]( @transient val sqlContext: SQLContext, @transient val queryExecution: QueryExecution, - unresolvedEncoder: Encoder[T]) extends Serializable { + tEncoder: Encoder[T]) extends Queryable with Serializable { + + /** + * An unresolved version of the internal encoder for the type of this dataset. This one is marked + * implicit so that we can use it when constructing new [[Dataset]] objects that have the same + * object type (that will be possibly resolved to a different schema). + */ + private implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ - private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match { - case e: ExpressionEncoder[T] => e.resolve(queryExecution.analyzed.output) - case _ => throw new IllegalArgumentException("Only expression encoders are currently supported") - } + private[sql] val resolvedTEncoder: ExpressionEncoder[T] = + unresolvedTEncoder.resolve(queryExecution.analyzed.output) - private implicit def classTag = encoder.clsTag + private implicit def classTag = resolvedTEncoder.clsTag private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) = this(sqlContext, new QueryExecution(sqlContext, plan), encoder) - /** Returns the schema of the encoded form of the objects in this [[Dataset]]. */ - def schema: StructType = encoder.schema + /** + * Returns the schema of the encoded form of the objects in this [[Dataset]]. + * + * @since 1.6.0 + */ + def schema: StructType = resolvedTEncoder.schema /* ************* * * Conversions * @@ -103,6 +112,7 @@ class Dataset[T] private[sql]( /** * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have * the same name after two Datasets have been joined. + * @since 1.6.0 */ def as(alias: String): Dataset[T] = withPlan(Subquery(alias, _)) @@ -128,7 +138,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def rdd: RDD[T] = { - val tEnc = encoderFor[T] + val tEnc = resolvedTEncoder val input = queryExecution.analyzed.output queryExecution.toRdd.mapPartitions { iter => val bound = tEnc.bind(input) @@ -166,8 +176,7 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. * @since 1.6.0 */ - def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] = - filter(t => func.call(t).booleanValue()) + def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t)) /** * (Scala-specific) @@ -181,7 +190,7 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] that contains the result of applying `func` to each element. * @since 1.6.0 */ - def map[U](func: JFunction[T, U], encoder: Encoder[U]): Dataset[U] = + def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = map(t => func.call(t))(encoder) /** @@ -190,7 +199,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { - new Dataset( + new Dataset[U]( sqlContext, MapPartitions[T, U]( func, @@ -205,10 +214,8 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] that contains the result of applying `func` to each element. * @since 1.6.0 */ - def mapPartitions[U]( - f: FlatMapFunction[java.util.Iterator[T], U], - encoder: Encoder[U]): Dataset[U] = { - val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator().asScala + def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator.asScala mapPartitions(func)(encoder) } @@ -248,7 +255,7 @@ class Dataset[T] private[sql]( * Runs `func` on each element of this Dataset. * @since 1.6.0 */ - def foreach(func: VoidFunction[T]): Unit = foreach(func.call(_)) + def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_)) /** * (Scala-specific) @@ -262,7 +269,7 @@ class Dataset[T] private[sql]( * Runs `func` on each partition of this Dataset. * @since 1.6.0 */ - def foreachPartition(func: VoidFunction[java.util.Iterator[T]]): Unit = + def foreachPartition(func: ForeachPartitionFunction[T]): Unit = foreachPartition(it => func.call(it.asJava)) /* ************* * @@ -271,7 +278,7 @@ class Dataset[T] private[sql]( /** * (Scala-specific) - * Reduces the elements of this Dataset using the specified binary function. The given function + * Reduces the elements of this Dataset using the specified binary function. The given function * must be commutative and associative or the result may be non-deterministic. * @since 1.6.0 */ @@ -279,33 +286,11 @@ class Dataset[T] private[sql]( /** * (Java-specific) - * Reduces the elements of this Dataset using the specified binary function. The given function + * Reduces the elements of this Dataset using the specified binary function. The given function * must be commutative and associative or the result may be non-deterministic. * @since 1.6.0 */ - def reduce(func: JFunction2[T, T, T]): T = reduce(func.call(_, _)) - - /** - * (Scala-specific) - * Aggregates the elements of each partition, and then the results for all the partitions, using a - * given associative and commutative function and a neutral "zero value". - * - * This behaves somewhat differently than the fold operations implemented for non-distributed - * collections in functional languages like Scala. This fold operation may be applied to - * partitions individually, and then those results will be folded into the final result. - * If op is not commutative, then the result may differ from that of a fold applied to a - * non-distributed collection. - * @since 1.6.0 - */ - def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op) - - /** - * (Java-specific) - * Aggregates the elements of each partition, and then the results for all the partitions, using a - * given associative and commutative function and a neutral "zero value". - * @since 1.6.0 - */ - def fold(zeroValue: T, func: JFunction2[T, T, T]): T = fold(zeroValue)(func.call(_, _)) + def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _)) /** * (Scala-specific) @@ -314,12 +299,12 @@ class Dataset[T] private[sql]( */ def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = { val inputPlan = queryExecution.analyzed - val withGroupingKey = AppendColumn(func, inputPlan) + val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan) val executed = sqlContext.executePlan(withGroupingKey) new GroupedDataset( - encoderFor[K].resolve(withGroupingKey.newColumns), - encoderFor[T].bind(inputPlan.output), + encoderFor[K], + encoderFor[T], executed, inputPlan.output, withGroupingKey.newColumns) @@ -351,7 +336,7 @@ class Dataset[T] private[sql]( * Returns a [[GroupedDataset]] where the data is grouped by the given key function. * @since 1.6.0 */ - def groupBy[K](f: JFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = + def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = groupBy(f.call(_))(encoder) /* ****************** * @@ -367,7 +352,7 @@ class Dataset[T] private[sql]( */ // Copied from Dataframe to make sure we don't have invalid overloads. @scala.annotation.varargs - def select(cols: Column*): DataFrame = toDF().select(cols: _*) + protected def select(cols: Column*): DataFrame = toDF().select(cols: _*) /** * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element. @@ -378,8 +363,14 @@ class Dataset[T] private[sql]( * }}} * @since 1.6.0 */ - def select[U1: Encoder](c1: TypedColumn[U1]): Dataset[U1] = { - new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan)) + def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { + new Dataset[U1]( + sqlContext, + Project( + c1.withInputType( + resolvedTEncoder, + queryExecution.analyzed.output).named :: Nil, + logicalPlan)) } /** @@ -387,17 +378,12 @@ class Dataset[T] private[sql]( * code reuse, we do this without the help of the type system and then use helper functions * that cast appropriately for the user facing interface. */ - protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = { - val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() } - val unresolvedPlan = Project(aliases, logicalPlan) - val execution = new QueryExecution(sqlContext, unresolvedPlan) - // Rebind the encoders to the nested schema that will be produced by the select. - val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map { - case (e: ExpressionEncoder[_], a) if !e.flat => - e.nested(a.toAttribute).resolve(execution.analyzed.output) - case (e, a) => - e.unbind(a.toAttribute :: Nil).resolve(execution.analyzed.output) - } + protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { + val encoders = columns.map(_.encoder) + val namedColumns = + columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named) + val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan)) + new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) } @@ -405,7 +391,7 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. * @since 1.6.0 */ - def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = + def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] /** @@ -413,9 +399,9 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1, U2, U3]( - c1: TypedColumn[U1], - c2: TypedColumn[U2], - c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] = selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]] /** @@ -423,10 +409,10 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1, U2, U3, U4]( - c1: TypedColumn[U1], - c2: TypedColumn[U2], - c3: TypedColumn[U3], - c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3], + c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] = selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]] /** @@ -434,11 +420,11 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1, U2, U3, U4, U5]( - c1: TypedColumn[U1], - c2: TypedColumn[U2], - c3: TypedColumn[U3], - c4: TypedColumn[U4], - c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3], + c4: TypedColumn[T, U4], + c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] = selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] /* **************** * @@ -462,8 +448,7 @@ class Dataset[T] private[sql]( * and thus is not affected by a custom `equals` function defined on `T`. * @since 1.6.0 */ - def intersect(other: Dataset[T]): Dataset[T] = - withPlan[T](other)(Intersect) + def intersect(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Intersect) /** * Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]] @@ -473,8 +458,7 @@ class Dataset[T] private[sql]( * duplicate items. As such, it is analagous to `UNION ALL` in SQL. * @since 1.6.0 */ - def union(other: Dataset[T]): Dataset[T] = - withPlan[T](other)(Union) + def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union) /** * Returns a new [[Dataset]] where any elements present in `other` have been removed. @@ -507,23 +491,18 @@ class Dataset[T] private[sql]( val left = this.logicalPlan val right = other.logicalPlan - val leftData = this.encoder match { + val leftData = this.unresolvedTEncoder match { case e if e.flat => Alias(left.output.head, "_1")() case _ => Alias(CreateStruct(left.output), "_1")() } - val rightData = other.encoder match { + val rightData = other.unresolvedTEncoder match { case e if e.flat => Alias(right.output.head, "_2")() case _ => Alias(CreateStruct(right.output), "_2")() } - val leftEncoder = - if (encoder.flat) encoder else encoder.nested(leftData.toAttribute) - val rightEncoder = - if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute) - implicit val tuple2Encoder: Encoder[(T, U)] = - ExpressionEncoder.tuple( - leftEncoder, - rightEncoder.rebind(right.output, left.output ++ right.output)) + + implicit val tuple2Encoder: Encoder[(T, U)] = + ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) withPlan[(T, U)](other) { (left, right) => Project( leftData :: rightData :: Nil, @@ -542,27 +521,47 @@ class Dataset[T] private[sql]( def first(): T = rdd.first() /** - * Collects the elements to an Array. + * Returns an array that contains all the elements in this [[Dataset]]. + * + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * + * For Java API, use [[collectAsList]]. * @since 1.6.0 */ def collect(): Array[T] = rdd.collect() /** - * (Java-specific) - * Collects the elements to a Java list. + * Returns an array that contains all the elements in this [[Dataset]]. * - * Due to the incompatibility problem between Scala and Java, the return type of [[collect()]] at - * Java side is `java.lang.Object`, which is not easy to use. Java user can use this method - * instead and keep the generic type for result. + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large dataset can crash the driver process with OutOfMemoryError. * + * For Java API, use [[collectAsList]]. * @since 1.6.0 */ - def collectAsList(): java.util.List[T] = - rdd.collect().toSeq.asJava + def collectAsList(): java.util.List[T] = rdd.collect().toSeq.asJava - /** Returns the first `num` elements of this [[Dataset]] as an Array. */ + /** + * Returns the first `num` elements of this [[Dataset]] as an array. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. + * + * @since 1.6.0 + */ def take(num: Int): Array[T] = rdd.take(num) + /** + * Returns the first `num` elements of this [[Dataset]] as an array. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. + * + * @since 1.6.0 + */ + def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) + /* ******************** * * Internal Functions * * ******************** */ @@ -570,7 +569,7 @@ class Dataset[T] private[sql]( private[sql] def logicalPlan = queryExecution.analyzed private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = - new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), encoder) + new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder) private[sql] def withPlan[R : Encoder]( other: Dataset[_])( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index f9eab5c2e965b..63dd7fbcbe9e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -21,10 +21,11 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} -import org.apache.spark.sql.types.NumericType +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate} +import org.apache.spark.sql.types.{StringType, NumericType} /** @@ -49,14 +50,8 @@ class GroupedData protected[sql]( aggExprs } - val aliasedAgg = aggregates.map { - // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we - // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to - // make it a NamedExpression. - case u: UnresolvedAttribute => UnresolvedAlias(u) - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } + val aliasedAgg = aggregates.map(alias) + groupType match { case GroupedData.GroupByType => DataFrame( @@ -67,10 +62,23 @@ class GroupedData protected[sql]( case GroupedData.CubeType => DataFrame( df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg)) + case GroupedData.PivotType(pivotCol, values) => + val aliasedGrps = groupingExprs.map(alias) + DataFrame( + df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) } } - private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression) + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + private[this] def alias(expr: Expression): NamedExpression = expr match { + case u: UnresolvedAttribute => UnresolvedAlias(u) + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } + + private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) : DataFrame = { val columnExprs = if (colNames.isEmpty) { @@ -88,30 +96,28 @@ class GroupedData protected[sql]( namedExpr } } - toDF(columnExprs.map(f)) + toDF(columnExprs.map(expr => f(expr).toAggregateExpression())) } private[this] def strToExpr(expr: String): (Expression => Expression) = { - expr.toLowerCase match { - case "avg" | "average" | "mean" => Average - case "max" => Max - case "min" => Min - case "stddev" | "std" => StddevSamp - case "stddev_pop" => StddevPop - case "stddev_samp" => StddevSamp - case "variance" => VarianceSamp - case "var_pop" => VariancePop - case "var_samp" => VarianceSamp - case "sum" => Sum - case "skewness" => Skewness - case "kurtosis" => Kurtosis - case "count" | "size" => - // Turn count(*) into count(1) - (inputExpr: Expression) => inputExpr match { - case s: Star => Count(Literal(1)) - case _ => Count(inputExpr) - } + val exprToFunc: (Expression => Expression) = { + (inputExpr: Expression) => expr.toLowerCase match { + // We special handle a few cases that have alias that are not in function registry. + case "avg" | "average" | "mean" => + UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false) + case "stddev" | "std" => + UnresolvedFunction("stddev", inputExpr :: Nil, isDistinct = false) + // Also special handle count because we need to take care count(*). + case "count" | "size" => + // Turn count(*) into count(1) + inputExpr match { + case s: Star => Count(Literal(1)).toAggregateExpression() + case _ => Count(inputExpr).toAggregateExpression() + } + case name => UnresolvedFunction(name, inputExpr :: Nil, isDistinct = false) + } } + (inputExpr: Expression) => exprToFunc(inputExpr) } /** @@ -213,7 +219,7 @@ class GroupedData protected[sql]( * * @since 1.3.0 */ - def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)), "count")())) + def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")())) /** * Compute the average value for each numeric columns for each group. This is an alias for `avg`. @@ -274,6 +280,77 @@ class GroupedData protected[sql]( def sum(colNames: String*): DataFrame = { aggregateNumericColumns(colNames : _*)(Sum) } + + /** + * (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified + * aggregation. + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")) + * // Or without specifying column values + * df.groupBy($"year").pivot($"course").agg(sum($"earnings")) + * }}} + * @param pivotColumn Column to pivot + * @param values Optional list of values of pivotColumn that will be translated to columns in the + * output data frame. If values are not provided the method with do an immediate + * call to .distinct() on the pivot column. + * @since 1.6.0 + */ + @scala.annotation.varargs + def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match { + case _: GroupedData.PivotType => + throw new UnsupportedOperationException("repeated pivots are not supported") + case GroupedData.GroupByType => + val pivotValues = if (values.nonEmpty) { + values.map { + case Column(literal: Literal) => literal + case other => + throw new UnsupportedOperationException( + s"The values of a pivot must be literals, found $other") + } + } else { + // This is to prevent unintended OOM errors when the number of distinct values is large + val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES) + // Get the distinct values of the column and sort them so its consistent + val values = df.select(pivotColumn) + .distinct() + .sort(pivotColumn) + .map(_.get(0)) + .take(maxValues + 1) + .map(Literal(_)).toSeq + if (values.length > maxValues) { + throw new RuntimeException( + s"The pivot column $pivotColumn has more than $maxValues distinct values, " + + "this could indicate an error. " + + "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " + + s"to at least the number of distinct values of the pivot column.") + } + values + } + new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues)) + case _ => + throw new UnsupportedOperationException("pivot is only supported after a groupBy") + } + + /** + * Pivots a column of the current [[DataFrame]] and preform the specified aggregation. + * {{{ + * // Compute the sum of earnings for each year by course with each course as a separate column + * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings") + * // Or without specifying column values + * df.groupBy("year").pivot("course").sum("earnings") + * }}} + * @param pivotColumn Column to pivot + * @param values Optional list of values of pivotColumn that will be translated to columns in the + * output data frame. If values are not provided the method with do an immediate + * call to .distinct() on the pivot column. + * @since 1.6.0 + */ + @scala.annotation.varargs + def pivot(pivotColumn: String, values: Any*): GroupedData = { + val resolvedPivotColumn = Column(df.resolve(pivotColumn)) + pivot(resolvedPivotColumn, values.map(functions.lit): _*) + } } @@ -308,4 +385,9 @@ private[sql] object GroupedData { * To indicate it's the ROLLUP */ private[sql] object RollupType extends GroupType + + /** + * To indicate it's the PIVOT + */ + private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index b2803d5a9a1e3..ebcf4c8bfe7e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql -import java.util.{Iterator => JIterator} + import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder} -import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute} +import org.apache.spark.api.java.function._ +import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -41,26 +40,21 @@ import org.apache.spark.sql.execution.QueryExecution */ @Experimental class GroupedDataset[K, T] private[sql]( - private val kEncoder: Encoder[K], - private val tEncoder: Encoder[T], - queryExecution: QueryExecution, + kEncoder: Encoder[K], + tEncoder: Encoder[T], + val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) extends Serializable { - private implicit val kEnc = kEncoder match { - case e: ExpressionEncoder[K] => e.unbind(groupingAttributes).resolve(groupingAttributes) - case other => - throw new UnsupportedOperationException("Only expression encoders are currently supported") - } + // Similar to [[Dataset]], we use unresolved encoders for later composition and resolved encoders + // when constructing new logical plans that will operate on the output of the current + // queryexecution. - private implicit val tEnc = tEncoder match { - case e: ExpressionEncoder[T] => e.resolve(dataAttributes) - case other => - throw new UnsupportedOperationException("Only expression encoders are currently supported") - } + private implicit val unresolvedKEncoder = encoderFor(kEncoder) + private implicit val unresolvedTEncoder = encoderFor(tEncoder) - /** Encoders for built in aggregations. */ - private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) + private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes) + private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes) private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext @@ -76,7 +70,7 @@ class GroupedDataset[K, T] private[sql]( def asKey[L : Encoder]: GroupedDataset[L, T] = new GroupedDataset( encoderFor[L], - tEncoder, + unresolvedTEncoder, queryExecution, dataAttributes, groupingAttributes) @@ -102,16 +96,53 @@ class GroupedDataset[K, T] private[sql]( * (for example, by calling `toList`) unless they are sure that this is possible given the memory * constraints of their cluster. */ - def mapGroups[U : Encoder](f: (K, Iterator[T]) => Iterator[U]): Dataset[U] = { + def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = { new Dataset[U]( sqlContext, - MapGroups(f, groupingAttributes, logicalPlan)) + MapGroups( + f, + resolvedKEncoder, + resolvedTEncoder, + groupingAttributes, + logicalPlan)) } - def mapGroups[U]( - f: JFunction2[K, JIterator[T], JIterator[U]], - encoder: Encoder[U]): Dataset[U] = { - mapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder) + def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { + flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + */ + def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = { + val func = (key: K, it: Iterator[T]) => Iterator(f(key, it)) + flatMap(func) + } + + def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { + map((key, data) => f.call(key, data.asJava))(encoder) + } + + /** + * Reduces the elements of each group of data using the specified binary function. + * The given function must be commutative and associative or the result may be non-deterministic. + */ + def reduce(f: (T, T) => T): Dataset[(K, T)] = { + val func = (key: K, it: Iterator[T]) => Iterator(key -> it.reduce(f)) + + implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder) + flatMap(func) + } + + def reduce(f: ReduceFunction[T]): Dataset[(K, T)] = { + reduce(f.call _) } // To ensure valid overloading. @@ -124,68 +155,60 @@ class GroupedDataset[K, T] private[sql]( * that cast appropriately for the user facing interface. * TODO: does not handle aggrecations that return nonflat results, */ - protected def aggUntyped(columns: TypedColumn[_]*): Dataset[_] = { - val aliases = (groupingAttributes ++ columns.map(_.expr)).map { - case u: UnresolvedAttribute => UnresolvedAlias(u) - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } - - val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan) - val execution = new QueryExecution(sqlContext, unresolvedPlan) - - val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]) - - // Rebind the encoders to the nested schema that will be produced by the aggregation. - val encoders = (kEnc +: columnEncoders).zip(execution.analyzed.output).map { - case (e: ExpressionEncoder[_], a) if !e.flat => - e.nested(a).resolve(execution.analyzed.output) - case (e, a) => - e.unbind(a :: Nil).resolve(execution.analyzed.output) - } - new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) + protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { + val encoders = columns.map(_.encoder) + val namedColumns = + columns.map( + _.withInputType(resolvedTEncoder, dataAttributes).named) + val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan) + val execution = new QueryExecution(sqlContext, aggregate) + + new Dataset( + sqlContext, + execution, + ExpressionEncoder.tuple(unresolvedKEncoder +: encoders)) } /** * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key * and the result of computing this aggregation over all elements in the group. */ - def agg[A1](col1: TypedColumn[A1]): Dataset[(K, A1)] = - aggUntyped(col1).asInstanceOf[Dataset[(K, A1)]] + def agg[U1](col1: TypedColumn[T, U1]): Dataset[(K, U1)] = + aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. */ - def agg[A1, A2](col1: TypedColumn[A1], col2: TypedColumn[A2]): Dataset[(K, A1, A2)] = - aggUntyped(col1, col2).asInstanceOf[Dataset[(K, A1, A2)]] + def agg[U1, U2](col1: TypedColumn[T, U1], col2: TypedColumn[T, U2]): Dataset[(K, U1, U2)] = + aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. */ - def agg[A1, A2, A3]( - col1: TypedColumn[A1], - col2: TypedColumn[A2], - col3: TypedColumn[A3]): Dataset[(K, A1, A2, A3)] = - aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, A1, A2, A3)]] + def agg[U1, U2, U3]( + col1: TypedColumn[T, U1], + col2: TypedColumn[T, U2], + col3: TypedColumn[T, U3]): Dataset[(K, U1, U2, U3)] = + aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. */ - def agg[A1, A2, A3, A4]( - col1: TypedColumn[A1], - col2: TypedColumn[A2], - col3: TypedColumn[A3], - col4: TypedColumn[A4]): Dataset[(K, A1, A2, A3, A4)] = - aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, A1, A2, A3, A4)]] + def agg[U1, U2, U3, U4]( + col1: TypedColumn[T, U1], + col2: TypedColumn[T, U2], + col3: TypedColumn[T, U3], + col4: TypedColumn[T, U4]): Dataset[(K, U1, U2, U3, U4)] = + aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] /** * Returns a [[Dataset]] that contains a tuple with each key and the number of items present * for that key. */ - def count(): Dataset[(K, Long)] = agg(functions.count("*").as[Long]) + def count(): Dataset[(K, Long)] = agg(functions.count("*").as(FlatEncoder[Long])) /** * Applies the given function to each cogrouped data. For each unique group, the function will @@ -195,8 +218,8 @@ class GroupedDataset[K, T] private[sql]( */ def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( - f: (K, Iterator[T], Iterator[U]) => Iterator[R]): Dataset[R] = { - implicit def uEnc: Encoder[U] = other.tEncoder + f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { + implicit def uEnc: Encoder[U] = other.unresolvedTEncoder new Dataset[R]( sqlContext, CoGroup( @@ -209,7 +232,7 @@ class GroupedDataset[K, T] private[sql]( def cogroup[U, R]( other: GroupedDataset[K, U], - f: JFunction3[K, JIterator[T], JIterator[U], JIterator[R]], + f: CoGroupFunction[K, T, U, R], encoder: Encoder[R]): Dataset[R] = { cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index ed8b634ad5630..f40e603cd1939 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -252,20 +252,9 @@ private[spark] object SQLConf { "not be provided to ExchangeCoordinator.", isPublic = false) - val TUNGSTEN_ENABLED = booleanConf("spark.sql.tungsten.enabled", + val SUBEXPRESSION_ELIMINATION_ENABLED = booleanConf("spark.sql.subexpressionElimination.enabled", defaultValue = Some(true), - doc = "When true, use the optimized Tungsten physical execution backend which explicitly " + - "manages memory and dynamically generates bytecode for expression evaluation.") - - val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", - defaultValue = Some(true), // use TUNGSTEN_ENABLED as default - doc = "When true, code will be dynamically generated at runtime for expression evaluation in" + - " a specific query.", - isPublic = false) - - val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", - defaultValue = Some(true), // use TUNGSTEN_ENABLED as default - doc = "When true, use the new optimized Tungsten physical execution backend.", + doc = "When true, common subexpressions will be eliminated.", isPublic = false) val DIALECT = stringConf( @@ -364,12 +353,6 @@ private[spark] object SQLConf { defaultValue = Some(5 * 60), doc = "Timeout in seconds for the broadcast wait time in broadcast joins.") - // Options that control which operators can be chosen by the query planner. These should be - // considered hints and may be ignored by future versions of Spark SQL. - val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin", - defaultValue = Some(true), - doc = "When true, use sort merge join (as opposed to hash join) by default for large joins.") - // This is only used for the thriftserver val THRIFTSERVER_POOL = stringConf("spark.sql.thriftserver.scheduler.pool", doc = "Set a Fair Scheduler pool for a JDBC client session") @@ -448,8 +431,12 @@ private[spark] object SQLConf { defaultValue = Some(true), isPublic = false) - val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2", - defaultValue = Some(true), doc = "") + val DATAFRAME_PIVOT_MAX_VALUES = intConf( + "spark.sql.pivotMaxValues", + defaultValue = Some(10000), + doc = "When doing a pivot without specifying values for the pivot column this is the maximum " + + "number of (distinct) values that will be collected without error." + ) val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles", defaultValue = Some(true), @@ -457,9 +444,26 @@ private[spark] object SQLConf { doc = "When true, we could use `datasource`.`path` as table in SQL query" ) + val SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING = + booleanConf("spark.sql.specializeSingleDistinctAggPlanning", + defaultValue = Some(true), + isPublic = false, + doc = "When true, if a query only has a single distinct column and it has " + + "grouping expressions, we will use our planner rule to handle this distinct " + + "column (other cases are handled by DistinctAggregationRewriter). " + + "When false, we will always use DistinctAggregationRewriter to plan " + + "aggregation queries with DISTINCT keyword. This is an internal flag that is " + + "used to benchmark the performance impact of using DistinctAggregationRewriter to " + + "plan aggregation queries with a single distinct column.") + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" + val USE_SQL_AGGREGATE2 = "spark.sql.useAggregate2" + val TUNGSTEN_ENABLED = "spark.sql.tungsten.enabled" + val CODEGEN_ENABLED = "spark.sql.codegen" + val UNSAFE_ENABLED = "spark.sql.unsafe.enabled" + val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin" } } @@ -524,15 +528,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def nativeView: Boolean = getConf(NATIVE_VIEW) - private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) - - private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, getConf(TUNGSTEN_ENABLED)) - def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) - private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED)) - - private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) + private[spark] def subexpressionEliminationEnabled: Boolean = + getConf(SUBEXPRESSION_ELIMINATION_ENABLED) private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) @@ -575,6 +574,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) + protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = + getConf(SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 5598731af5fcc..cd1fdc4edb39d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -21,7 +21,6 @@ import java.beans.{BeanInfo, Introspector} import java.util.Properties import java.util.concurrent.atomic.AtomicReference - import scala.collection.JavaConverters._ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag @@ -34,7 +33,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 6da46a5f7ef9a..8471eea1b7d9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -37,17 +37,21 @@ import org.apache.spark.unsafe.types.UTF8String abstract class SQLImplicits { protected def _sqlContext: SQLContext - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder[T]() + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T] - implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder[Int](flat = true) - implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) - implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder[Double](flat = true) - implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder[Float](flat = true) - implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder[Byte](flat = true) - implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder[Short](flat = true) - implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true) - implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true) + implicit def newIntEncoder: Encoder[Int] = FlatEncoder[Int] + implicit def newLongEncoder: Encoder[Long] = FlatEncoder[Long] + implicit def newDoubleEncoder: Encoder[Double] = FlatEncoder[Double] + implicit def newFloatEncoder: Encoder[Float] = FlatEncoder[Float] + implicit def newByteEncoder: Encoder[Byte] = FlatEncoder[Byte] + implicit def newShortEncoder: Encoder[Short] = FlatEncoder[Short] + implicit def newBooleanEncoder: Encoder[Boolean] = FlatEncoder[Boolean] + implicit def newStringEncoder: Encoder[String] = FlatEncoder[String] + /** + * Creates a [[Dataset]] from an RDD. + * @since 1.6.0 + */ implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = { DatasetHolder(_sqlContext.createDataset(rdd)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 7eb1ad7cd8198..2cface61e59c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -125,7 +125,7 @@ private[sql] case class InMemoryRelation( private def buildBuffers(): Unit = { val output = child.output - val cached = child.execute().mapPartitions { rowIterator => + val cached = child.execute().mapPartitionsInternal { rowIterator => new Iterator[CachedBatch] { def next(): CachedBatch = { val columnBuilders = output.map { attribute => @@ -292,7 +292,7 @@ private[sql] case class InMemoryColumnarTableScan( val relOutput = relation.output val buffers = relation.cachedColumnBuffers - buffers.mapPartitions { cachedBatchIterator => + buffers.mapPartitionsInternal { cachedBatchIterator => val partitionFilter = newPredicate( partitionFilters.reduceOption(And).getOrElse(Literal(true)), schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala deleted file mode 100644 index 6f3f1bd97ad52..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.HashMap - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each - * group. - * - * @param partial if true then aggregation is done partially on local data without shuffling to - * ensure all values where `groupingExpressions` are equal are present. - * @param groupingExpressions expressions that are evaluated to determine grouping. - * @param aggregateExpressions expressions that are computed for each group. - * @param child the input data source. - */ -case class Aggregate( - partial: Boolean, - groupingExpressions: Seq[Expression], - aggregateExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryNode { - - override private[sql] lazy val metrics = Map( - "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def requiredChildDistribution: List[Distribution] = { - if (partial) { - UnspecifiedDistribution :: Nil - } else { - if (groupingExpressions == Nil) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingExpressions) :: Nil - } - } - } - - override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) - - /** - * An aggregate that needs to be computed for each row in a group. - * - * @param unbound Unbound version of this aggregate, used for result substitution. - * @param aggregate A bound copy of this aggregate used to create a new aggregation buffer. - * @param resultAttribute An attribute used to refer to the result of this aggregate in the final - * output. - */ - case class ComputedAggregate( - unbound: AggregateExpression1, - aggregate: AggregateExpression1, - resultAttribute: AttributeReference) - - /** A list of aggregates that need to be computed for each group. */ - private[this] val computedAggregates = aggregateExpressions.flatMap { agg => - agg.collect { - case a: AggregateExpression1 => - ComputedAggregate( - a, - BindReferences.bindReference(a, child.output), - AttributeReference(s"aggResult:$a", a.dataType, a.nullable)()) - } - }.toArray - - /** The schema of the result of all aggregate evaluations */ - private[this] val computedSchema = computedAggregates.map(_.resultAttribute) - - /** Creates a new aggregate buffer for a group. */ - private[this] def newAggregateBuffer(): Array[AggregateFunction1] = { - val buffer = new Array[AggregateFunction1](computedAggregates.length) - var i = 0 - while (i < computedAggregates.length) { - buffer(i) = computedAggregates(i).aggregate.newInstance() - i += 1 - } - buffer - } - - /** Named attributes used to substitute grouping attributes into the final result. */ - private[this] val namedGroups = groupingExpressions.map { - case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute - } - - /** - * A map of substitutions that are used to insert the aggregate expressions and grouping - * expression into the final result expression. - */ - private[this] val resultMap = - (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute } ++ namedGroups).toMap - - /** - * Substituted version of aggregateExpressions expressions which are used to compute final - * output rows given a group and the result of all aggregate computations. - */ - private[this] val resultExpressions = aggregateExpressions.map { agg => - agg.transform { - case e: Expression if resultMap.contains(e) => resultMap(e) - } - } - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - val numInputRows = longMetric("numInputRows") - val numOutputRows = longMetric("numOutputRows") - if (groupingExpressions.isEmpty) { - child.execute().mapPartitions { iter => - val buffer = newAggregateBuffer() - var currentRow: InternalRow = null - while (iter.hasNext) { - currentRow = iter.next() - numInputRows += 1 - var i = 0 - while (i < buffer.length) { - buffer(i).update(currentRow) - i += 1 - } - } - val resultProjection = new InterpretedProjection(resultExpressions, computedSchema) - val aggregateResults = new GenericMutableRow(computedAggregates.length) - - var i = 0 - while (i < buffer.length) { - aggregateResults(i) = buffer(i).eval(EmptyRow) - i += 1 - } - - numOutputRows += 1 - Iterator(resultProjection(aggregateResults)) - } - } else { - child.execute().mapPartitions { iter => - val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]] - val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output) - - var currentRow: InternalRow = null - while (iter.hasNext) { - currentRow = iter.next() - numInputRows += 1 - val currentGroup = groupingProjection(currentRow) - var currentBuffer = hashTable.get(currentGroup) - if (currentBuffer == null) { - currentBuffer = newAggregateBuffer() - hashTable.put(currentGroup.copy(), currentBuffer) - } - - var i = 0 - while (i < currentBuffer.length) { - currentBuffer(i).update(currentRow) - i += 1 - } - } - - new Iterator[InternalRow] { - private[this] val hashTableIter = hashTable.entrySet().iterator() - private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) - private[this] val resultProjection = - new InterpretedMutableProjection( - resultExpressions, computedSchema ++ namedGroups.map(_._2)) - private[this] val joinedRow = new JoinedRow - - override final def hasNext: Boolean = hashTableIter.hasNext - - override final def next(): InternalRow = { - val currentEntry = hashTableIter.next() - val currentGroup = currentEntry.getKey - val currentBuffer = currentEntry.getValue - numOutputRows += 1 - - var i = 0 - while (i < currentBuffer.length) { - // Evaluating an aggregate buffer returns the result. No row is required since we - // already added all rows in the group using update. - aggregateResults(i) = currentBuffer(i).eval(EmptyRow) - i += 1 - } - resultProjection(joinedRow(aggregateResults, currentGroup)) - } - } - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index a4ce328c1a9eb..62cbc518e02af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -44,23 +44,20 @@ case class Exchange( override def nodeName: String = { val extraInfo = coordinator match { case Some(exchangeCoordinator) if exchangeCoordinator.isEstimated => - "Shuffle" + s"(coordinator id: ${System.identityHashCode(coordinator)})" case Some(exchangeCoordinator) if !exchangeCoordinator.isEstimated => - "May shuffle" - case None => "Shuffle without coordinator" + s"(coordinator id: ${System.identityHashCode(coordinator)})" + case None => "" } val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange" - s"$simpleNodeName($extraInfo)" + s"$simpleNodeName$extraInfo" } /** * Returns true iff we can support the data type, and we are not doing range partitioning. */ - private lazy val tungstenMode: Boolean = { - unsafeEnabled && codegenEnabled && GenerateUnsafeProjection.canSupport(child.schema) && - !newPartitioning.isInstanceOf[RangePartitioning] - } + private lazy val tungstenMode: Boolean = !newPartitioning.isInstanceOf[RangePartitioning] override def outputPartitioning: Partitioning = newPartitioning @@ -171,7 +168,7 @@ case class Exchange( case RangePartitioning(sortingExpressions, numPartitions) => // Internally, RangePartitioner runs a job on the RDD that samples keys to compute // partition bounds. To get accurate samples, we need to copy the mutable keys. - val rddForSampling = rdd.mapPartitions { iter => + val rddForSampling = rdd.mapPartitionsInternal { iter => val mutablePair = new MutablePair[InternalRow, Null]() iter.map(row => mutablePair.update(row.copy(), null)) } @@ -203,12 +200,12 @@ case class Exchange( } val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { if (needToCopyObjectsBeforeShuffle(part, serializer)) { - rdd.mapPartitions { iter => + rdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } } } else { - rdd.mapPartitions { iter => + rdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() val mutablePair = new MutablePair[Int, InternalRow]() iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } @@ -478,10 +475,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ if (requiredOrdering.nonEmpty) { // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) { - sqlContext.planner.BasicOperators.getSortOperator( - requiredOrdering, - global = false, - child) + Sort(requiredOrdering, global = false, child = child) } else { child } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 7a466cf6a0a94..62620ec642c78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -106,29 +106,9 @@ private[sql] object PhysicalRDD { def createFromDataSource( output: Seq[Attribute], rdd: RDD[InternalRow], - relation: BaseRelation): PhysicalRDD = { - PhysicalRDD(output, rdd, relation.toString, relation.isInstanceOf[HadoopFsRelation]) + relation: BaseRelation, + extraInformation: String = ""): PhysicalRDD = { + PhysicalRDD(output, rdd, relation.toString + extraInformation, + relation.isInstanceOf[HadoopFsRelation]) } } - -/** Logical plan node for scanning data from a local collection. */ -private[sql] -case class LogicalLocalTable(output: Seq[Attribute], rows: Seq[InternalRow])(sqlContext: SQLContext) - extends LogicalPlan with MultiInstanceRelation { - - override def children: Seq[LogicalPlan] = Nil - - override def newInstance(): this.type = - LogicalLocalTable(output.map(_.newInstance()), rows)(sqlContext).asInstanceOf[this.type] - - override def sameResult(plan: LogicalPlan): Boolean = plan match { - case LogicalRDD(_, otherRDD) => rows == rows - case _ => false - } - - @transient override lazy val statistics: Statistics = Statistics( - // TODO: Improve the statistics estimation. - // This is made small enough so it can be broadcasted. - sizeInBytes = sqlContext.conf.autoBroadcastJoinThreshold - 1 - ) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index 55e95769d3faa..91530bd63798a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -45,6 +45,9 @@ case class Expand( override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = true + override def references: AttributeSet = + AttributeSet(projections.flatten.flatMap(_.references)) + private[this] val projection = { if (outputsUnsafeRows) { (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 78e33d9f233a6..54b8cb58285c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -59,7 +59,7 @@ case class Generate( protected override def doExecute(): RDD[InternalRow] = { // boundGenerator.terminate() should be triggered after all of the rows in the partition if (join) { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) val joinedRow = new JoinedRow @@ -79,7 +79,7 @@ case class Generate( } } } else { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => iter.flatMap(row => boundGenerator.eval(row)) ++ LazyIterator(() => boundGenerator.terminate()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index c2142d03f422b..5da5aea17e25b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution -import com.google.common.annotations.VisibleForTesting - import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow @@ -33,7 +31,6 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan */ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { - @VisibleForTesting def assertAnalyzed(): Unit = sqlContext.analyzer.checkAnalysis(analyzed) lazy val analyzed: LogicalPlan = sqlContext.analyzer.execute(logical) @@ -80,7 +77,6 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { |${stringOrError(optimizedPlan)} |== Physical Plan == |${stringOrError(executedPlan)} - |Code Generation: ${stringOrError(executedPlan.codegenEnabled)} """.stripMargin.trim } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala new file mode 100644 index 0000000000000..9ca383896a09b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.types.StructType + +import scala.util.control.NonFatal + +/** A trait that holds shared code between DataFrames and Datasets. */ +private[sql] trait Queryable { + def schema: StructType + def queryExecution: QueryExecution + + override def toString: String = { + try { + schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") + } catch { + case NonFatal(e) => + s"Invalid tree; ${e.getMessage}:\n$queryExecution" + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala similarity index 62% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 1a3832a698b61..24207cb46fd29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -17,68 +17,22 @@ package org.apache.spark.sql.execution +import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.CompletionIterator -import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// This file defines various sort operators. -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/** - * Performs a sort, spilling to disk as needed. - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - */ -case class Sort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends UnaryNode { - - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitions( { iterator => - val ordering = newOrdering(sortOrder, child.output) - val sorter = new ExternalSorter[InternalRow, Null, InternalRow]( - TaskContext.get(), ordering = Some(ordering)) - sorter.insertAll(iterator.map(r => (r.copy(), null))) - val baseIterator = sorter.iterator.map(_._1) - val context = TaskContext.get() - context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) - context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) - // TODO(marmbrus): The complex type signature below thwarts inference for no reason. - CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) - }, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder -} /** - * Optimized version of [[Sort]] that operates on binary data (implemented as part of - * Project Tungsten). + * Performs (external) sorting. * * @param global when true performs a global sort of all partitions by shuffling the data first * if necessary. * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will * spill every `frequency` records. */ - -case class TungstenSort( +case class Sort( sortOrder: Seq[SortOrder], global: Boolean, child: SparkPlan, @@ -107,7 +61,7 @@ case class TungstenSort( val dataSize = longMetric("dataSize") val spillSize = longMetric("spillSize") - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val ordering = newOrdering(sortOrder, childOutput) // The comparator for comparing prefix @@ -143,14 +97,4 @@ case class TungstenSort( sortedIterator } } - -} - -object TungstenSort { - /** - * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. - */ - def supportsSchema(schema: StructType): Boolean = { - UnsafeExternalRowSorter.supportsSchema(schema) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 8bb293ae87e64..534a3bcb8364d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -54,15 +54,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def sparkContext = sqlContext.sparkContext // sqlContext will be null when we are being deserialized on the slaves. In this instance - // the value of codegenEnabled/unsafeEnabled will be set by the desserializer after the + // the value of subexpressionEliminationEnabled will be set by the desserializer after the // constructor has run. - val codegenEnabled: Boolean = if (sqlContext != null) { - sqlContext.conf.codegenEnabled - } else { - false - } - val unsafeEnabled: Boolean = if (sqlContext != null) { - sqlContext.conf.unsafeEnabled + val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) { + sqlContext.conf.subexpressionEliminationEnabled } else { false } @@ -226,87 +221,52 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ private[this] def isTesting: Boolean = sys.props.contains("spark.testing") - protected def newProjection( - expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { - log.debug( - s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled) { - try { - GenerateProjection.generate(expressions, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate projection, fallback to interpret", e) - new InterpretedProjection(expressions, inputSchema) - } - } - } else { - new InterpretedProjection(expressions, inputSchema) - } - } - protected def newMutableProjection( - expressions: Seq[Expression], - inputSchema: Seq[Attribute]): () => MutableProjection = { - log.debug( - s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if(codegenEnabled) { - try { - GenerateMutableProjection.generate(expressions, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate mutable projection, fallback to interpreted", e) - () => new InterpretedMutableProjection(expressions, inputSchema) - } - } - } else { - () => new InterpretedMutableProjection(expressions, inputSchema) + expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { + log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema") + try { + GenerateMutableProjection.generate(expressions, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate mutable projection, fallback to interpreted", e) + () => new InterpretedMutableProjection(expressions, inputSchema) + } } } protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { - if (codegenEnabled) { - try { - GeneratePredicate.generate(expression, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate predicate, fallback to interpreted", e) - InterpretedPredicate.create(expression, inputSchema) - } - } - } else { - InterpretedPredicate.create(expression, inputSchema) + try { + GeneratePredicate.generate(expression, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate predicate, fallback to interpreted", e) + InterpretedPredicate.create(expression, inputSchema) + } } } protected def newOrdering( - order: Seq[SortOrder], - inputSchema: Seq[Attribute]): Ordering[InternalRow] = { - if (codegenEnabled) { - try { - GenerateOrdering.generate(order, inputSchema) - } catch { - case e: Exception => - if (isTesting) { - throw e - } else { - log.error("Failed to generate ordering, fallback to interpreted", e) - new InterpretedOrdering(order, inputSchema) - } - } - } else { - new InterpretedOrdering(order, inputSchema) + order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow] = { + try { + GenerateOrdering.generate(order, inputSchema) + } catch { + case e: Exception => + if (isTesting) { + throw e + } else { + log.error("Failed to generate ordering, fallback to interpreted", e) + new InterpretedOrdering(order, inputSchema) + } } } + /** * Creates a row ordering for the given schema, in natural ascending order. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 0f98fe88b2101..6e9a4df828246 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -18,19 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.datasources.DataSourceStrategy -@Experimental class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { val sparkContext: SparkContext = sqlContext.sparkContext - def codegenEnabled: Boolean = sqlContext.conf.codegenEnabled - - def unsafeEnabled: Boolean = sqlContext.conf.unsafeEnabled - def numPartitions: Int = sqlContext.conf.numShufflePartitions def strategies: Seq[Strategy] = @@ -38,7 +32,6 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { DataSourceStrategy :: DDLStrategy :: TakeOrderedAndProject :: - HashAggregation :: Aggregation :: LeftSemiJoin :: EquiJoinSelection :: @@ -69,7 +62,7 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { val projectSet = AttributeSet(projectList.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = + val filterCondition: Option[Expression] = prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And) // Right now we still use a projection even if the only evaluation is applying an alias diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index b19ad4f1c563e..8317f648ccb4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -22,19 +22,16 @@ import java.util.{HashMap => JavaHashMap} import scala.reflect.ClassTag -import com.clearspring.analytics.stream.cardinality.HyperLogLog import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Kryo, Serializer} import com.twitter.chill.ResourcePool import org.apache.spark.serializer.{KryoSerializer, SerializerInstance} -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet} import org.apache.spark.sql.types.Decimal import org.apache.spark.util.MutablePair -import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.{SparkConf, SparkEnv} + private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { val kryo = super.newKryo() @@ -43,16 +40,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericInternalRow]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow]) - kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog], - new HyperLogLogSerializer) kryo.register(classOf[java.math.BigDecimal], new JavaBigDecimalSerializer) kryo.register(classOf[BigDecimal], new ScalaBigDecimalSerializer) - // Specific hashsets must come first TODO: Move to core. - kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer) - kryo.register(classOf[LongHashSet], new LongHashSetSerializer) - kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], - new OpenHashSetSerializer) kryo.register(classOf[Decimal]) kryo.register(classOf[JavaHashMap[_, _]]) @@ -62,7 +52,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co } private[execution] class KryoResourcePool(size: Int) - extends ResourcePool[SerializerInstance](size) { + extends ResourcePool[SerializerInstance](size) { val ser: SparkSqlSerializer = { val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) @@ -116,92 +106,3 @@ private[sql] class ScalaBigDecimalSerializer extends Serializer[BigDecimal] { new java.math.BigDecimal(input.readString()) } } - -private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] { - def write(kryo: Kryo, output: Output, hyperLogLog: HyperLogLog) { - val bytes = hyperLogLog.getBytes() - output.writeInt(bytes.length) - output.writeBytes(bytes) - } - - def read(kryo: Kryo, input: Input, tpe: Class[HyperLogLog]): HyperLogLog = { - val length = input.readInt() - val bytes = input.readBytes(length) - HyperLogLog.Builder.build(bytes) - } -} - -private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] { - def write(kryo: Kryo, output: Output, hs: OpenHashSet[_]) { - val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]] - output.writeInt(hs.size) - val iterator = hs.iterator - while(iterator.hasNext) { - val row = iterator.next() - rowSerializer.write(kryo, output, row.asInstanceOf[GenericInternalRow].values) - } - } - - def read(kryo: Kryo, input: Input, tpe: Class[OpenHashSet[_]]): OpenHashSet[_] = { - val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]] - val numItems = input.readInt() - val set = new OpenHashSet[Any](numItems + 1) - var i = 0 - while (i < numItems) { - val row = - new GenericInternalRow(rowSerializer.read( - kryo, - input, - classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]]) - set.add(row) - i += 1 - } - set - } -} - -private[sql] class IntegerHashSetSerializer extends Serializer[IntegerHashSet] { - def write(kryo: Kryo, output: Output, hs: IntegerHashSet) { - output.writeInt(hs.size) - val iterator = hs.iterator - while(iterator.hasNext) { - val value: Int = iterator.next() - output.writeInt(value) - } - } - - def read(kryo: Kryo, input: Input, tpe: Class[IntegerHashSet]): IntegerHashSet = { - val numItems = input.readInt() - val set = new IntegerHashSet - var i = 0 - while (i < numItems) { - val value = input.readInt() - set.add(value) - i += 1 - } - set - } -} - -private[sql] class LongHashSetSerializer extends Serializer[LongHashSet] { - def write(kryo: Kryo, output: Output, hs: LongHashSet) { - output.writeInt(hs.size) - val iterator = hs.iterator - while(iterator.hasNext) { - val value = iterator.next() - output.writeLong(value) - } - } - - def read(kryo: Kryo, input: Input, tpe: Class[LongHashSet]): LongHashSet = { - val numItems = input.readInt() - val set = new LongHashSet - var i = 0 - while (i < numItems) { - val value = input.readLong() - set.add(value) - i += 1 - } - set - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index dd3bb33c57287..3d4ce633c07c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, Utils} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} @@ -73,10 +73,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side * of the join will be broadcasted and the other side will be streamed, with no shuffling * performed. If both sides of the join are eligible to be broadcasted then the - * - Sort merge: if the matching join keys are sortable and - * [[org.apache.spark.sql.SQLConf.SORTMERGE_JOIN]] is enabled (default), then sort merge join - * will be used. - * - Hash: will be chosen if neither of the above optimizations apply to this join. + * - Sort merge: if the matching join keys are sortable. */ object EquiJoinSelection extends Strategy with PredicateHelper { @@ -103,22 +100,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => + if RowOrdering.isOrderable(leftKeys) => val mergeJoin = joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => - val buildSide = - if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { - joins.BuildRight - } else { - joins.BuildLeft - } - val hashJoin = joins.ShuffledHashJoin( - leftKeys, rightKeys, buildSide, planLater(left), planLater(right)) - condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil - // --- Outer joins -------------------------------------------------------------------------- case ExtractEquiJoinKeys( @@ -132,162 +118,114 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) => + if RowOrdering.isOrderable(leftKeys) => joins.SortMergeOuterJoin( leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) => - joins.ShuffledHashOuterJoin( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil - // --- Cases where this strategy does not apply --------------------------------------------- case _ => Nil } } - object HashAggregation extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // Aggregations that can be performed in two phases, before and after the shuffle. - case PartialAggregation( - namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - partialComputation, - child) if !canBeConvertedToNewAggregation(plan) => - execution.Aggregate( - partial = false, - namedGroupingAttributes, - rewrittenAggregateExpressions, - execution.Aggregate( - partial = true, - groupingExpressions, - partialComputation, - planLater(child))) :: Nil - - case _ => Nil - } - - def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = plan match { - case a: logical.Aggregate => - if (sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled) { - a.newAggregation.isDefined - } else { - Utils.checkInvalidAggregateFunction2(a) - false - } - case _ => false - } - - def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] = - exprs.flatMap(_.collect { case a: AggregateExpression1 => a }) - } - /** * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case p: logical.Aggregate if sqlContext.conf.useSqlAggregate2 && - sqlContext.conf.codegenEnabled => - val converted = p.newAggregation - converted match { - case None => Nil // Cannot convert to new aggregation code path. - case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => - // A single aggregate expression might appear multiple times in resultExpressions. - // In order to avoid evaluating an individual aggregate function multiple times, we'll - // build a set of the distinct aggregate expressions and build a function which can - // be used to re-write expressions so that they reference the single copy of the - // aggregate function which actually gets computed. - val aggregateExpressions = resultExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg - } - }.distinct - // For those distinct aggregate expressions, we create a map from the - // aggregate function to the corresponding attribute of the function. - val aggregateFunctionToAttribute = aggregateExpressions.map { agg => - val aggregateFunction = agg.aggregateFunction - val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute - (aggregateFunction, agg.isDistinct) -> attribute - }.toMap - - val (functionsWithDistinct, functionsWithoutDistinct) = - aggregateExpressions.partition(_.isDistinct) - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - // This is a sanity check. We should not reach here when we have multiple distinct - // column sets (aggregate.NewAggregation will not match). - sys.error( - "Multiple distinct column sets are not supported by the new aggregation" + - "code path.") - } + case logical.Aggregate(groupingExpressions, resultExpressions, child) => + // A single aggregate expression might appear multiple times in resultExpressions. + // In order to avoid evaluating an individual aggregate function multiple times, we'll + // build a set of the distinct aggregate expressions and build a function which can + // be used to re-write expressions so that they reference the single copy of the + // aggregate function which actually gets computed. + val aggregateExpressions = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => agg + } + }.distinct + // For those distinct aggregate expressions, we create a map from the + // aggregate function to the corresponding attribute of the function. + val aggregateFunctionToAttribute = aggregateExpressions.map { agg => + val aggregateFunction = agg.aggregateFunction + val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute + (aggregateFunction, agg.isDistinct) -> attribute + }.toMap + + val (functionsWithDistinct, functionsWithoutDistinct) = + aggregateExpressions.partition(_.isDistinct) + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + // This is a sanity check. We should not reach here when we have multiple distinct + // column sets. Our MultipleDistinctRewriter should take care this case. + sys.error("You hit a query analyzer bug. Please report your query to " + + "Spark user mailing list.") + } - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - - // The original `resultExpressions` are a set of expressions which may reference - // aggregate expressions, grouping column values, and constants. When aggregate operator - // emits output rows, we will use `resultExpressions` to generate an output projection - // which takes the grouping columns and final aggregate result buffer as input. - // Thus, we must re-write the result expressions so that their attributes match up with - // the attributes of the final result projection's input row: - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case AggregateExpression2(aggregateFunction, _, isDistinct) => - // The final aggregation buffer's attributes will be `finalAggregationAttributes`, - // so replace each aggregate expression by its corresponding attribute in the set: - aggregateFunctionToAttribute(aggregateFunction, isDistinct) - case expression => - // Since we're using `namedGroupingAttributes` to extract the grouping key - // columns, we need to replace grouping key expressions with their corresponding - // attributes. We do not rely on the equality check at here since attributes may - // differ cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + + // The original `resultExpressions` are a set of expressions which may reference + // aggregate expressions, grouping column values, and constants. When aggregate operator + // emits output rows, we will use `resultExpressions` to generate an output projection + // which takes the grouping columns and final aggregate result buffer as input. + // Thus, we must re-write the result expressions so that their attributes match up with + // the attributes of the final result projection's input row: + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case AggregateExpression(aggregateFunction, _, isDistinct) => + // The final aggregation buffer's attributes will be `finalAggregationAttributes`, + // so replace each aggregate expression by its corresponding attribute in the set: + aggregateFunctionToAttribute(aggregateFunction, isDistinct) + case expression => + // Since we're using `namedGroupingAttributes` to extract the grouping key + // columns, we need to replace grouping key expressions with their corresponding + // attributes. We do not rely on the equality check at here since attributes may + // differ cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + val aggregateOperator = + if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) { + if (functionsWithDistinct.nonEmpty) { + sys.error("Distinct columns cannot exist in Aggregate operator containing " + + "aggregate functions which don't support partial aggregation.") + } else { + aggregate.Utils.planAggregateWithoutPartial( + namedGroupingExpressions.map(_._2), + aggregateExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, + planLater(child)) } + } else if (functionsWithDistinct.isEmpty) { + aggregate.Utils.planAggregateWithoutDistinct( + namedGroupingExpressions.map(_._2), + aggregateExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, + planLater(child)) + } else { + aggregate.Utils.planAggregateWithOneDistinct( + namedGroupingExpressions.map(_._2), + functionsWithDistinct, + functionsWithoutDistinct, + aggregateFunctionToAttribute, + rewrittenResultExpressions, + planLater(child)) + } - val aggregateOperator = - if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) { - if (functionsWithDistinct.nonEmpty) { - sys.error("Distinct columns cannot exist in Aggregate operator containing " + - "aggregate functions which don't support partial aggregation.") - } else { - aggregate.Utils.planAggregateWithoutPartial( - namedGroupingExpressions.map(_._2), - aggregateExpressions, - aggregateFunctionToAttribute, - rewrittenResultExpressions, - planLater(child)) - } - } else if (functionsWithDistinct.isEmpty) { - aggregate.Utils.planAggregateWithoutDistinct( - namedGroupingExpressions.map(_._2), - aggregateExpressions, - aggregateFunctionToAttribute, - rewrittenResultExpressions, - planLater(child)) - } else { - aggregate.Utils.planAggregateWithOneDistinct( - namedGroupingExpressions.map(_._2), - functionsWithDistinct, - functionsWithoutDistinct, - aggregateFunctionToAttribute, - rewrittenResultExpressions, - planLater(child)) - } - - aggregateOperator - } + aggregateOperator case _ => Nil } @@ -364,21 +302,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object BasicOperators extends Strategy { def numPartitions: Int = self.numPartitions - /** - * Picks an appropriate sort operator. - * - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - */ - def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { - if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && - TungstenSort.supportsSchema(child.schema)) { - execution.TungstenSort(sortExprs, global, child) - } else { - execution.Sort(sortExprs, global, child) - } - } - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: RunnableCommand => ExecutedCommand(r) :: Nil @@ -388,7 +311,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.MapPartitions(f, tEnc, uEnc, output, child) => execution.MapPartitions(f, tEnc, uEnc, output, planLater(child)) :: Nil - case logical.AppendColumn(f, tEnc, uEnc, newCol, child) => + case logical.AppendColumns(f, tEnc, uEnc, newCol, child) => execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) => execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil @@ -406,34 +329,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.SortPartitions(sortExprs, child) => // This sort only sorts tuples within a partition. Its requiredDistribution will be // an UnspecifiedDistribution. - getSortOperator(sortExprs, global = false, planLater(child)) :: Nil + execution.Sort(sortExprs, global = false, child = planLater(child)) :: Nil case logical.Sort(sortExprs, global, child) => - getSortOperator(sortExprs, global, planLater(child)):: Nil + execution.Sort(sortExprs, global, planLater(child)) :: Nil case logical.Project(projectList, child) => - // If unsafe mode is enabled and we support these data types in Unsafe, use the - // Tungsten project. Otherwise, use the normal project. - if (sqlContext.conf.unsafeEnabled && - UnsafeProjection.canSupport(projectList) && UnsafeProjection.canSupport(child.schema)) { - execution.TungstenProject(projectList, planLater(child)) :: Nil - } else { - execution.Project(projectList, planLater(child)) :: Nil - } + execution.Project(projectList, planLater(child)) :: Nil case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil case e @ logical.Expand(_, _, child) => execution.Expand(e.projections, e.output, planLater(child)) :: Nil - case a @ logical.Aggregate(group, agg, child) => { - val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled - if (useNewAggregation && a.newAggregation.isDefined) { - // If this logical.Aggregate can be planned to use new aggregation code path - // (i.e. it can be planned by the Strategy Aggregation), we will not use the old - // aggregation code path. - Nil - } else { - Utils.checkInvalidAggregateFunction2(a) - execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil - } - } case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) => execution.Window( projectList, windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 53c5ccf8fa37e..b1280c32a6a43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -247,11 +247,7 @@ case class Window( // Get all relevant projections. val result = createResultProjection(unboundExpressions) - val grouping = if (child.outputsUnsafeRows) { - UnsafeProjection.create(partitionSpec, child.output) - } else { - newProjection(partitionSpec, child.output) - } + val grouping = UnsafeProjection.create(partitionSpec, child.output) // Manage the stream and the grouping. var nextRow: InternalRow = EmptyRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 99fb7a40b72e1..008478a6a0e17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -35,9 +35,9 @@ import scala.collection.mutable.ArrayBuffer abstract class AggregationIterator( groupingKeyAttributes: Seq[Attribute], valueAttributes: Seq[Attribute], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateExpressions: Seq[AggregateExpression], nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression], completeAggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], @@ -76,14 +76,14 @@ abstract class AggregationIterator( // Initialize all AggregateFunctions by binding references if necessary, // and set inputBufferOffset and mutableBufferOffset. - protected val allAggregateFunctions: Array[AggregateFunction2] = { + protected val allAggregateFunctions: Array[AggregateFunction] = { var mutableBufferOffset = 0 var inputBufferOffset: Int = initialInputBufferOffset - val functions = new Array[AggregateFunction2](allAggregateExpressions.length) + val functions = new Array[AggregateFunction](allAggregateExpressions.length) var i = 0 while (i < allAggregateExpressions.length) { val func = allAggregateExpressions(i).aggregateFunction - val funcWithBoundReferences: AggregateFunction2 = allAggregateExpressions(i).mode match { + val funcWithBoundReferences: AggregateFunction = allAggregateExpressions(i).mode match { case Partial | Complete if func.isInstanceOf[ImperativeAggregate] => // We need to create BoundReferences if the function is not an // expression-based aggregate function (it does not support code-gen) and the mode of @@ -135,7 +135,7 @@ abstract class AggregationIterator( } // All AggregateFunctions functions with mode Partial, PartialMerge, or Final. - private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction2] = + private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.take(nonCompleteAggregateExpressions.length) // All imperative aggregate functions with mode Partial, PartialMerge, or Final. @@ -172,7 +172,7 @@ abstract class AggregationIterator( case (Some(Partial), None) => val updateExpressions = nonCompleteAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val expressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() @@ -204,7 +204,7 @@ abstract class AggregationIterator( // allAggregateFunctions.flatMap(_.cloneBufferAttributes) val mergeExpressions = nonCompleteAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } // This projection is used to merge buffer values for all expression-based aggregates. val expressionAggMergeProjection = @@ -225,7 +225,7 @@ abstract class AggregationIterator( // Final-Complete case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = + val completeAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) // All imperative aggregate functions with mode Complete. val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = @@ -248,7 +248,7 @@ abstract class AggregationIterator( val mergeExpressions = nonCompleteAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } ++ completeOffsetExpressions val finalExpressionAggMergeProjection = newMutableProjection(mergeExpressions, mergeInputSchema)() @@ -256,7 +256,7 @@ abstract class AggregationIterator( val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val completeExpressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() @@ -282,7 +282,7 @@ abstract class AggregationIterator( // Complete-only case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = + val completeAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) // All imperative aggregate functions with mode Complete. val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = @@ -291,7 +291,7 @@ abstract class AggregationIterator( val updateExpressions = completeAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val completeExpressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() @@ -353,7 +353,7 @@ abstract class AggregationIterator( allAggregateFunctions.flatMap(_.aggBufferAttributes) val evalExpressions = allAggregateFunctions.map { case ae: DeclarativeAggregate => ae.evaluateExpression - case agg: AggregateFunction2 => NoOp + case agg: AggregateFunction => NoOp } val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)() val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index 4d37106e007f5..ee982453c3287 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -29,9 +29,9 @@ import org.apache.spark.sql.execution.metric.SQLMetrics case class SortBasedAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateExpressions: Seq[AggregateExpression], nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression], completeAggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], @@ -69,7 +69,7 @@ case class SortBasedAggregate( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => // Because the constructor of an aggregation iterator will read at least the first row, // we need to get the value of iter.hasNext first. val hasInput = iter.hasNext @@ -78,11 +78,9 @@ case class SortBasedAggregate( // so return an empty iterator. Iterator[InternalRow]() } else { - val groupingKeyProjection = if (UnsafeProjection.canSupport(groupingExpressions)) { + val groupingKeyProjection = UnsafeProjection.create(groupingExpressions, child.output) - } else { - newMutableProjection(groupingExpressions, child.output)() - } + val outputIter = new SortBasedAggregationIterator( groupingKeyProjection, groupingExpressions.map(_.toAttribute), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 64c673064f576..fe5c3195f867b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} import org.apache.spark.sql.execution.metric.LongSQLMetric /** - * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been + * An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been * sorted by values of [[groupingKeyAttributes]]. */ class SortBasedAggregationIterator( @@ -31,9 +31,9 @@ class SortBasedAggregationIterator( groupingKeyAttributes: Seq[Attribute], valueAttributes: Seq[Attribute], inputIterator: Iterator[InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateExpressions: Seq[AggregateExpression], nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression], completeAggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 15616915f7364..920de615e1d86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} @@ -30,9 +30,9 @@ import org.apache.spark.sql.types.StructType case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateExpressions: Seq[AggregateExpression], nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression], completeAggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], @@ -139,7 +139,6 @@ object TungstenAggregate { groupingExpressions: Seq[Expression], aggregateBufferAttributes: Seq[Attribute]): Boolean = { val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeProjection.canSupport(groupingExpressions) + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index ce8d592c368ee..04391443920ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -64,12 +64,12 @@ import org.apache.spark.sql.types.StructType * @param groupingExpressions * expressions for grouping keys * @param nonCompleteAggregateExpressions - * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]], - * [[PartialMerge]], or [[Final]]. + * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Partial]], + * [[PartialMerge]], or [[Final]]. * @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions' * outputs when they are stored in the final aggregation buffer. * @param completeAggregateExpressions - * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]]. + * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Complete]]. * @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs * when they are stored in the final aggregation buffer. * @param resultExpressions @@ -83,9 +83,9 @@ import org.apache.spark.sql.types.StructType */ class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateExpressions: Seq[AggregateExpression], nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression], completeAggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], @@ -106,7 +106,7 @@ class TungstenAggregationIterator( // A Seq containing all AggregateExpressions. // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final // are at the beginning of the allAggregateExpressions. - private[this] val allAggregateExpressions: Seq[AggregateExpression2] = + private[this] val allAggregateExpressions: Seq[AggregateExpression] = nonCompleteAggregateExpressions ++ completeAggregateExpressions // Check to make sure we do not have more than three modes in our AggregateExpressions. @@ -150,10 +150,10 @@ class TungstenAggregationIterator( // Initialize all AggregateFunctions by binding references, if necessary, // and setting inputBufferOffset and mutableBufferOffset. private def initializeAllAggregateFunctions( - startingInputBufferOffset: Int): Array[AggregateFunction2] = { + startingInputBufferOffset: Int): Array[AggregateFunction] = { var mutableBufferOffset = 0 var inputBufferOffset: Int = startingInputBufferOffset - val functions = new Array[AggregateFunction2](allAggregateExpressions.length) + val functions = new Array[AggregateFunction](allAggregateExpressions.length) var i = 0 while (i < allAggregateExpressions.length) { val func = allAggregateExpressions(i).aggregateFunction @@ -195,7 +195,7 @@ class TungstenAggregationIterator( functions } - private[this] var allAggregateFunctions: Array[AggregateFunction2] = + private[this] var allAggregateFunctions: Array[AggregateFunction] = initializeAllAggregateFunctions(initialInputBufferOffset) // Positions of those imperative aggregate functions in allAggregateFunctions. @@ -263,7 +263,7 @@ class TungstenAggregationIterator( case (Some(Partial), None) => val updateExpressions = allAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val imperativeAggregateFunctions: Array[ImperativeAggregate] = allAggregateFunctions.collect { case func: ImperativeAggregate => func} @@ -286,7 +286,7 @@ class TungstenAggregationIterator( case (Some(PartialMerge), None) | (Some(Final), None) => val mergeExpressions = allAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val imperativeAggregateFunctions: Array[ImperativeAggregate] = allAggregateFunctions.collect { case func: ImperativeAggregate => func} @@ -307,11 +307,11 @@ class TungstenAggregationIterator( // Final-Complete case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = + val completeAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - val nonCompleteAggregateFunctions: Array[AggregateFunction2] = + val nonCompleteAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.take(nonCompleteAggregateExpressions.length) val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } @@ -321,7 +321,7 @@ class TungstenAggregationIterator( val mergeExpressions = nonCompleteAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } ++ completeOffsetExpressions val finalMergeProjection = newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() @@ -331,7 +331,7 @@ class TungstenAggregationIterator( Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val completeUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() @@ -358,7 +358,7 @@ class TungstenAggregationIterator( // Complete-only case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = + val completeAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) // All imperative aggregate functions with mode Complete. val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = @@ -366,7 +366,7 @@ class TungstenAggregationIterator( val updateExpressions = completeAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val completeExpressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() @@ -414,7 +414,7 @@ class TungstenAggregationIterator( val joinedRow = new JoinedRow() val evalExpressions = allAggregateFunctions.map { case ae: DeclarativeAggregate => ae.evaluateExpression - case agg: AggregateFunction2 => NoOp + case agg: AggregateFunction => NoOp } val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() // These are the attributes of the row produced by `expressionAggEvalProjection` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala new file mode 100644 index 0000000000000..3f2775896bb8c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import scala.language.existentials + +import org.apache.spark.Logging +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +object TypedAggregateExpression { + def apply[A, B : Encoder, C : Encoder]( + aggregator: Aggregator[A, B, C]): TypedAggregateExpression = { + new TypedAggregateExpression( + aggregator.asInstanceOf[Aggregator[Any, Any, Any]], + None, + encoderFor[B].asInstanceOf[ExpressionEncoder[Any]], + encoderFor[C].asInstanceOf[ExpressionEncoder[Any]], + Nil, + 0, + 0) + } +} + +/** + * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has + * the following limitations: + * - It assumes the aggregator reduces and returns a single column of type `long`. + * - It might only work when there is a single aggregator in the first column. + * - It assumes the aggregator has a zero, `0`. + */ +case class TypedAggregateExpression( + aggregator: Aggregator[Any, Any, Any], + aEncoder: Option[ExpressionEncoder[Any]], + bEncoder: ExpressionEncoder[Any], + cEncoder: ExpressionEncoder[Any], + children: Seq[Attribute], + mutableAggBufferOffset: Int, + inputAggBufferOffset: Int) + extends ImperativeAggregate with Logging { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def nullable: Boolean = true + + override def dataType: DataType = if (cEncoder.flat) { + cEncoder.schema.head.dataType + } else { + cEncoder.schema + } + + override def deterministic: Boolean = true + + override lazy val resolved: Boolean = aEncoder.isDefined + + override lazy val inputTypes: Seq[DataType] = Nil + + override val aggBufferSchema: StructType = bEncoder.schema + + override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + + // We let the dataset do the binding for us. + lazy val boundA = aEncoder.get + + val bAttributes = bEncoder.schema.toAttributes + lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes) + + private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { + // todo: need a more neat way to assign the value. + var i = 0 + while (i < aggBufferAttributes.length) { + aggBufferSchema(i).dataType match { + case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i)) + case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i)) + } + i += 1 + } + } + + override def initialize(buffer: MutableRow): Unit = { + val zero = bEncoder.toRow(aggregator.zero) + updateBuffer(buffer, zero) + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + val inputA = boundA.fromRow(input) + val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val merged = aggregator.reduce(currentB, inputA) + val returned = boundB.toRow(merged) + + updateBuffer(buffer, returned) + } + + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1) + val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2) + val merged = aggregator.merge(b1, b2) + val returned = boundB.toRow(merged) + + updateBuffer(buffer1, returned) + } + + override def eval(buffer: InternalRow): Any = { + val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val result = cEncoder.toRow(aggregator.finish(b)) + dataType match { + case _: StructType => result + case _ => result.get(0, dataType) + } + } + + override def toString: String = { + s"""${aggregator.getClass.getSimpleName}(${children.mkString(",")})""" + } + + override def nodeName: String = aggregator.getClass.getSimpleName +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index d2f56e0fc14a4..20359c1e540e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression} -import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, AggregateFunction2} +import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, AggregateFunction} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index eaafd83158a15..a70e41436c7aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -28,8 +28,8 @@ object Utils { def planAggregateWithoutPartial( groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { @@ -54,18 +54,15 @@ object Utils { def planAggregateWithoutDistinct( groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use TungstenAggregate. - val usesTungstenAggregate = - child.sqlContext.conf.unsafeEnabled && - TungstenAggregate.supportsAggregate( + val usesTungstenAggregate = TungstenAggregate.supportsAggregate( groupingExpressions, aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) - // 1. Create an Aggregate Operator for partial aggregations. val groupingAttributes = groupingExpressions.map(_.toAttribute) @@ -137,18 +134,16 @@ object Utils { def planAggregateWithOneDistinct( groupingExpressions: Seq[NamedExpression], - functionsWithDistinct: Seq[AggregateExpression2], - functionsWithoutDistinct: Seq[AggregateExpression2], - aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], + functionsWithDistinct: Seq[AggregateExpression], + functionsWithoutDistinct: Seq[AggregateExpression], + aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct - val usesTungstenAggregate = - child.sqlContext.conf.unsafeEnabled && - TungstenAggregate.supportsAggregate( - groupingExpressions, - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) + val usesTungstenAggregate = TungstenAggregate.supportsAggregate( + groupingExpressions, + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one // DISTINCT aggregate function, all of those functions will have the same column expression. @@ -253,16 +248,16 @@ object Utils { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. - case agg @ AggregateExpression2(aggregateFunction, mode, true) => + case agg @ AggregateExpression(aggregateFunction, mode, true) => val rewrittenAggregateFunction = aggregateFunction.transformDown { case expr if expr == distinctColumnExpression => distinctColumnAttribute - }.asInstanceOf[AggregateFunction2] + }.asInstanceOf[AggregateFunction] // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, // they still can see distinct aggregations. val rewrittenAggregateExpression = - AggregateExpression2(rewrittenAggregateFunction, Complete, isDistinct = true) + AggregateExpression(rewrittenAggregateFunction, Complete, isDistinct = true) val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true) (rewrittenAggregateExpression, aggregateFunctionAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 799650a4f784f..e79092efdaa3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -31,32 +31,6 @@ import org.apache.spark.{HashPartitioner, SparkEnv} case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = projectList.map(_.toAttribute) - - override private[sql] lazy val metrics = Map( - "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) - - @transient lazy val buildProjection = newMutableProjection(projectList, child.output) - - protected override def doExecute(): RDD[InternalRow] = { - val numRows = longMetric("numRows") - child.execute().mapPartitions { iter => - val reusableProjection = buildProjection() - iter.map { row => - numRows += 1 - reusableProjection(row) - } - } - } - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering -} - - -/** - * A variant of [[Project]] that returns [[UnsafeRow]]s. - */ -case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { override private[sql] lazy val metrics = Map( "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) @@ -69,8 +43,9 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) protected override def doExecute(): RDD[InternalRow] = { val numRows = longMetric("numRows") - child.execute().mapPartitions { iter => - val project = UnsafeProjection.create(projectList, child.output) + child.execute().mapPartitionsInternal { iter => + val project = UnsafeProjection.create(projectList, child.output, + subexpressionEliminationEnabled) iter.map { row => numRows += 1 project(row) @@ -92,7 +67,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { protected override def doExecute(): RDD[InternalRow] = { val numInputRows = longMetric("numInputRows") val numOutputRows = longMetric("numOutputRows") - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val predicate = newPredicate(condition, child.output) iter.filter { row => numInputRows += 1 @@ -186,11 +161,11 @@ case class Limit(limit: Int, child: SparkPlan) protected override def doExecute(): RDD[InternalRow] = { val rdd: RDD[_ <: Product2[Boolean, InternalRow]] = if (sortBasedShuffleOn) { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => iter.take(limit).map(row => (false, row.copy())) } } else { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val mutablePair = new MutablePair[Boolean, InternalRow]() iter.take(limit).map(row => mutablePair.update(false, row)) } @@ -198,7 +173,7 @@ case class Limit(limit: Int, child: SparkPlan) val part = new HashPartitioner(1) val shuffled = new ShuffledRDD[Boolean, InternalRow, InternalRow](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf)) - shuffled.mapPartitions(_.take(limit).map(_._2)) + shuffled.mapPartitionsInternal(_.take(limit).map(_._2)) } } @@ -319,7 +294,7 @@ case class MapPartitions[T, U]( child: SparkPlan) extends UnaryNode { override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val tBoundEncoder = tEncoder.bind(child.output) func(iter.map(tBoundEncoder.fromRow)).map(uEncoder.toRow) } @@ -336,10 +311,14 @@ case class AppendColumns[T, U]( newColumns: Seq[Attribute], child: SparkPlan) extends UnaryNode { + // We are using an unsafe combiner. + override def canProcessSafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def output: Seq[Attribute] = child.output ++ newColumns override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val tBoundEncoder = tEncoder.bind(child.output) val combiner = GenerateUnsafeRowJoiner.create(tEncoder.schema, uEncoder.schema) iter.map { row => @@ -356,7 +335,7 @@ case class AppendColumns[T, U]( * being output. */ case class MapGroups[K, T, U]( - func: (K, Iterator[T]) => Iterator[U], + func: (K, Iterator[T]) => TraversableOnce[U], kEncoder: ExpressionEncoder[K], tEncoder: ExpressionEncoder[T], uEncoder: ExpressionEncoder[U], @@ -371,14 +350,15 @@ case class MapGroups[K, T, U]( Seq(groupingAttributes.map(SortOrder(_, Ascending))) override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => + child.execute().mapPartitionsInternal { iter => val grouped = GroupedIterator(iter, groupingAttributes, child.output) val groupKeyEncoder = kEncoder.bind(groupingAttributes) + val groupDataEncoder = tEncoder.bind(child.output) grouped.flatMap { case (key, rowIter) => val result = func( groupKeyEncoder.fromRow(key), - rowIter.map(tEncoder.fromRow)) + rowIter.map(groupDataEncoder.fromRow)) result.map(uEncoder.toRow) } } @@ -391,7 +371,7 @@ case class MapGroups[K, T, U]( * The result of this function is encoded and flattened before being output. */ case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], + func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], kEncoder: ExpressionEncoder[K], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index e5f60b15e7359..24a79f289aa81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -111,6 +111,54 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) + case Some((SQLConf.Deprecated.USE_SQL_AGGREGATE2, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} is deprecated and " + + s"will be ignored. ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} will " + + s"continue to be true.") + Seq(Row(SQLConf.Deprecated.USE_SQL_AGGREGATE2, "true")) + } + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.TUNGSTEN_ENABLED, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.TUNGSTEN_ENABLED} is deprecated and " + + s"will be ignored. Tungsten will continue to be used.") + Seq(Row(SQLConf.Deprecated.TUNGSTEN_ENABLED, "true")) + } + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.CODEGEN_ENABLED, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.CODEGEN_ENABLED} is deprecated and " + + s"will be ignored. Codegen will continue to be used.") + Seq(Row(SQLConf.Deprecated.CODEGEN_ENABLED, "true")) + } + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.UNSAFE_ENABLED, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.UNSAFE_ENABLED} is deprecated and " + + s"will be ignored. Unsafe mode will continue to be used.") + Seq(Row(SQLConf.Deprecated.UNSAFE_ENABLED, "true")) + } + (keyValueOutput, runFunc) + + (keyValueOutput, runFunc) + + case Some((SQLConf.Deprecated.SORTMERGE_JOIN, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.SORTMERGE_JOIN} is deprecated and " + + s"will be ignored. Sort merge join will continue to be used.") + Seq(Row(SQLConf.Deprecated.SORTMERGE_JOIN, "true")) + } + (keyValueOutput, runFunc) + // Configures a single property. case Some((key, Some(value))) => val runFunc = (sqlContext: SQLContext) => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 7265d6a4de2e6..544d5eccec037 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -315,6 +315,8 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // `Filter`s or cannot be handled by `relation`. val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And) + val pushedFiltersString = pushedFilters.mkString(" PushedFilter: [", ",", "] ") + if (projects.map(_.toAttribute) == projects && projectSet.size == projects.size && filterSet.subsetOf(projectSet)) { @@ -332,7 +334,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val scan = execution.PhysicalRDD.createFromDataSource( projects.map(_.toAttribute), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), - relation.relation) + relation.relation, pushedFiltersString) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { // Don't request columns that are only referenced by pushed filters. @@ -342,8 +344,9 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val scan = execution.PhysicalRDD.createFromDataSource( requestedColumns, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), - relation.relation) - execution.Project(projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) + relation.relation, pushedFiltersString) + execution.Project( + projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } } @@ -453,8 +456,8 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { * * @return A pair of `Seq[Expression]` and `Seq[Filter]`. The first element contains all Catalyst * predicate [[Expression]]s that are either not convertible or cannot be handled by - * `relation`. The second element contains all converted data source [[Filter]]s that can - * be handled by `relation`. + * `relation`. The second element contains all converted data source [[Filter]]s that + * will be pushed down to the data source. */ protected[sql] def selectFilters( relation: BaseRelation, @@ -476,7 +479,9 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Catalyst predicate expressions that cannot be translated to data source filters. val unrecognizedPredicates = predicates.filterNot(translatedMap.contains) - // Data source filters that cannot be handled by `relation` + // Data source filters that cannot be handled by `relation`. The semantic of a unhandled filter + // at here is that a data source may not be able to apply this filter to every row + // of the underlying dataset. val unhandledFilters = relation.unhandledFilters(translatedMap.values.toArray).toSet val (unhandled, handled) = translated.partition { @@ -491,6 +496,11 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Translated data source filters that can be handled by `relation` val (_, handledFilters) = handled.unzip - (unrecognizedPredicates ++ unhandledPredicates, handledFilters) + // translated contains all filters that have been converted to the public Filter interface. + // We should always push them to the data source no matter whether the data source can apply + // a filter to every row or not. + val (_, translatedFilters) = translated.unzip + + (unrecognizedPredicates ++ unhandledPredicates, translatedFilters) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 86bc3a1b6dab2..81962f8d63789 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -75,10 +75,11 @@ private[sql] object PartitioningUtils { private[sql] def parsePartitions( paths: Seq[Path], defaultPartitionName: String, - typeInference: Boolean): PartitionSpec = { + typeInference: Boolean, + basePaths: Set[Path]): PartitionSpec = { // First, we need to parse every partition's path and see if we can find partition values. - val (partitionValues, optBasePaths) = paths.map { path => - parsePartition(path, defaultPartitionName, typeInference) + val (partitionValues, optDiscoveredBasePaths) = paths.map { path => + parsePartition(path, defaultPartitionName, typeInference, basePaths) }.unzip // We create pairs of (path -> path's partition value) here @@ -101,11 +102,15 @@ private[sql] object PartitioningUtils { // It will be recognised as conflicting directory structure: // "hdfs://host:9000/invalidPath" // "hdfs://host:9000/path" - val basePaths = optBasePaths.flatMap(x => x) + val disvoeredBasePaths = optDiscoveredBasePaths.flatMap(x => x) assert( - basePaths.distinct.size == 1, + disvoeredBasePaths.distinct.size == 1, "Conflicting directory structures detected. Suspicious paths:\b" + - basePaths.distinct.mkString("\n\t", "\n\t", "\n\n")) + disvoeredBasePaths.distinct.mkString("\n\t", "\n\t", "\n\n") + + "If provided paths are partition directories, please set " + + "\"basePath\" in the options of the data source to specify the " + + "root directory of the table. If there are multiple root directories, " + + "please load them separately and then union them.") val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues) @@ -131,7 +136,7 @@ private[sql] object PartitioningUtils { /** * Parses a single partition, returns column names and values of each partition column, also - * the base path. For example, given: + * the path when we stop partition discovery. For example, given: * {{{ * path = hdfs://:/path/to/partition/a=42/b=hello/c=3.14 * }}} @@ -144,40 +149,63 @@ private[sql] object PartitioningUtils { * Literal.create("hello", StringType), * Literal.create(3.14, FloatType))) * }}} - * and the base path: + * and the path when we stop the discovery is: * {{{ - * /path/to/partition + * hdfs://:/path/to/partition * }}} */ private[sql] def parsePartition( path: Path, defaultPartitionName: String, - typeInference: Boolean): (Option[PartitionValues], Option[Path]) = { + typeInference: Boolean, + basePaths: Set[Path]): (Option[PartitionValues], Option[Path]) = { val columns = ArrayBuffer.empty[(String, Literal)] // Old Hadoop versions don't have `Path.isRoot` var finished = path.getParent == null - var chopped = path - var basePath = path + // currentPath is the current path that we will use to parse partition column value. + var currentPath: Path = path while (!finished) { // Sometimes (e.g., when speculative task is enabled), temporary directories may be left - // uncleaned. Here we simply ignore them. - if (chopped.getName.toLowerCase == "_temporary") { + // uncleaned. Here we simply ignore them. + if (currentPath.getName.toLowerCase == "_temporary") { return (None, None) } - val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName, typeInference) - maybeColumn.foreach(columns += _) - basePath = chopped - chopped = chopped.getParent - finished = (maybeColumn.isEmpty && !columns.isEmpty) || chopped.getParent == null + if (basePaths.contains(currentPath)) { + // If the currentPath is one of base paths. We should stop. + finished = true + } else { + // Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1. + // Once we get the string, we try to parse it and find the partition column and value. + val maybeColumn = + parsePartitionColumn(currentPath.getName, defaultPartitionName, typeInference) + maybeColumn.foreach(columns += _) + + // Now, we determine if we should stop. + // When we hit any of the following cases, we will stop: + // - In this iteration, we could not parse the value of partition column and value, + // i.e. maybeColumn is None, and columns is not empty. At here we check if columns is + // empty to handle cases like /table/a=1/_temporary/something (we need to find a=1 in + // this case). + // - After we get the new currentPath, this new currentPath represent the top level dir + // i.e. currentPath.getParent == null. For the example of "/table/a=1/", + // the top level dir is "/table". + finished = + (maybeColumn.isEmpty && !columns.isEmpty) || currentPath.getParent == null + + if (!finished) { + // For the above example, currentPath will be "/table/". + currentPath = currentPath.getParent + } + } } if (columns.isEmpty) { (None, Some(path)) } else { val (columnNames, values) = columns.reverse.unzip - (Some(PartitionValues(columnNames, values)), Some(basePath)) + (Some(PartitionValues(columnNames, values)), Some(currentPath)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala index 6773afc794f9c..f522303be94ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql.execution.datasources.jdbc diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index b9914c581a657..922fd5b21167b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -25,33 +25,36 @@ import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -private[sql] object InferSchema { + +private[json] object InferSchema { + /** * Infer the type of a collection of json records in three stages: * 1. Infer the type of each record * 2. Merge types by choosing the lowest type necessary to cover equal keys * 3. Replace any remaining null fields with string, the top type */ - def apply( + def infer( json: RDD[String], - samplingRatio: Double = 1.0, columnNameOfCorruptRecords: String, - primitivesAsString: Boolean = false): StructType = { - require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") - val schemaData = if (samplingRatio > 0.99) { + configOptions: JSONOptions): StructType = { + require(configOptions.samplingRatio > 0, + s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0") + val schemaData = if (configOptions.samplingRatio > 0.99) { json } else { - json.sample(withReplacement = false, samplingRatio, 1) + json.sample(withReplacement = false, configOptions.samplingRatio, 1) } // perform schema inference on each row and merge afterwards val rootType = schemaData.mapPartitions { iter => val factory = new JsonFactory() + configOptions.setJacksonOptions(factory) iter.map { row => try { Utils.tryWithResource(factory.createParser(row)) { parser => parser.nextToken() - inferField(parser, primitivesAsString) + inferField(parser, configOptions) } } catch { case _: JsonParseException => @@ -71,14 +74,14 @@ private[sql] object InferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField(parser: JsonParser, primitivesAsString: Boolean): DataType = { + private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType case FIELD_NAME => parser.nextToken() - inferField(parser, primitivesAsString) + inferField(parser, configOptions) case VALUE_STRING if parser.getTextLength < 1 => // Zero length strings and nulls have special handling to deal @@ -95,7 +98,7 @@ private[sql] object InferSchema { while (nextUntil(parser, END_OBJECT)) { builder += StructField( parser.getCurrentName, - inferField(parser, primitivesAsString), + inferField(parser, configOptions), nullable = true) } @@ -107,14 +110,15 @@ private[sql] object InferSchema { // the type as we pass through all JSON objects. var elementType: DataType = NullType while (nextUntil(parser, END_ARRAY)) { - elementType = compatibleType(elementType, inferField(parser, primitivesAsString)) + elementType = compatibleType( + elementType, inferField(parser, configOptions)) } ArrayType(elementType) - case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if primitivesAsString => StringType + case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType - case (VALUE_TRUE | VALUE_FALSE) if primitivesAsString => StringType + case (VALUE_TRUE | VALUE_FALSE) if configOptions.primitivesAsString => StringType case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT => import JsonParser.NumberType._ @@ -178,7 +182,7 @@ private[sql] object InferSchema { /** * Returns the most general data type for two given data types. */ - private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { + def compatibleType(t1: DataType, t2: DataType): DataType = { HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala new file mode 100644 index 0000000000000..c132ead20e7d6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.json + +import com.fasterxml.jackson.core.{JsonParser, JsonFactory} + +/** + * Options for the JSON data source. + * + * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. + */ +case class JSONOptions( + samplingRatio: Double = 1.0, + primitivesAsString: Boolean = false, + allowComments: Boolean = false, + allowUnquotedFieldNames: Boolean = false, + allowSingleQuotes: Boolean = true, + allowNumericLeadingZeros: Boolean = false, + allowNonNumericNumbers: Boolean = false) { + + /** Sets config options on a Jackson [[JsonFactory]]. */ + def setJacksonOptions(factory: JsonFactory): Unit = { + factory.configure(JsonParser.Feature.ALLOW_COMMENTS, allowComments) + factory.configure(JsonParser.Feature.ALLOW_UNQUOTED_FIELD_NAMES, allowUnquotedFieldNames) + factory.configure(JsonParser.Feature.ALLOW_SINGLE_QUOTES, allowSingleQuotes) + factory.configure(JsonParser.Feature.ALLOW_NUMERIC_LEADING_ZEROS, allowNumericLeadingZeros) + factory.configure(JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS, allowNonNumericNumbers) + } +} + + +object JSONOptions { + def createFromConfigMap(parameters: Map[String, String]): JSONOptions = JSONOptions( + samplingRatio = + parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0), + primitivesAsString = + parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false), + allowComments = + parameters.get("allowComments").map(_.toBoolean).getOrElse(false), + allowUnquotedFieldNames = + parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false), + allowSingleQuotes = + parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true), + allowNumericLeadingZeros = + parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false), + allowNonNumericNumbers = + parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) + ) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 85b52f04c8d01..3e61ba35bea8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -52,29 +52,28 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { dataSchema: Option[StructType], partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation = { - val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) - val primitivesAsString = parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false) new JSONRelation( - None, - samplingRatio, - primitivesAsString, - dataSchema, - None, - partitionColumns, - paths)(sqlContext) + inputRDD = None, + maybeDataSchema = dataSchema, + maybePartitionSpec = None, + userDefinedPartitionColumns = partitionColumns, + paths = paths, + parameters = parameters)(sqlContext) } } private[sql] class JSONRelation( val inputRDD: Option[RDD[String]], - val samplingRatio: Double, - val primitivesAsString: Boolean, val maybeDataSchema: Option[StructType], val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], - override val paths: Array[String] = Array.empty[String])(@transient val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec) { + override val paths: Array[String] = Array.empty[String], + parameters: Map[String, String] = Map.empty[String, String]) + (@transient val sqlContext: SQLContext) + extends HadoopFsRelation(maybePartitionSpec, parameters) { + + val options: JSONOptions = JSONOptions.createFromConfigMap(parameters) /** Constraints to be imposed on schema to be stored. */ private def checkConstraints(schema: StructType): Unit = { @@ -106,17 +105,16 @@ private[sql] class JSONRelation( classOf[Text]).map(_._2.toString) // get the text line } - override lazy val dataSchema = { + override lazy val dataSchema: StructType = { val jsonSchema = maybeDataSchema.getOrElse { val files = cachedLeafStatuses().filterNot { status => val name = status.getPath.getName name.startsWith("_") || name.startsWith(".") }.toArray - InferSchema( + InferSchema.infer( inputRDD.getOrElse(createBaseRdd(files)), - samplingRatio, sqlContext.conf.columnNameOfCorruptRecord, - primitivesAsString) + options) } checkConstraints(jsonSchema) @@ -129,10 +127,11 @@ private[sql] class JSONRelation( inputPaths: Array[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { val requiredDataSchema = StructType(requiredColumns.map(dataSchema(_))) - val rows = JacksonParser( + val rows = JacksonParser.parse( inputRDD.getOrElse(createBaseRdd(inputPaths)), requiredDataSchema, - sqlContext.conf.columnNameOfCorruptRecord) + sqlContext.conf.columnNameOfCorruptRecord, + options) rows.mapPartitions { iterator => val unsafeProjection = UnsafeProjection.create(requiredDataSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index 4f53eeb081b93..bfa1405041058 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql.execution.datasources.json import java.io.ByteArrayOutputStream +import scala.collection.mutable.ArrayBuffer import com.fasterxml.jackson.core._ -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -32,18 +31,23 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -private[sql] object JacksonParser { - def apply( - json: RDD[String], +object JacksonParser { + + def parse( + input: RDD[String], schema: StructType, - columnNameOfCorruptRecords: String): RDD[InternalRow] = { - parseJson(json, schema, columnNameOfCorruptRecords) + columnNameOfCorruptRecords: String, + configOptions: JSONOptions): RDD[InternalRow] = { + + input.mapPartitions { iter => + parseJson(iter, schema, columnNameOfCorruptRecords, configOptions) + } } /** * Parse the current token (and related children) according to a desired schema */ - private[sql] def convertField( + def convertField( factory: JsonFactory, parser: JsonParser, schema: DataType): Any = { @@ -226,9 +230,10 @@ private[sql] object JacksonParser { } private def parseJson( - json: RDD[String], + input: Iterator[String], schema: StructType, - columnNameOfCorruptRecords: String): RDD[InternalRow] = { + columnNameOfCorruptRecords: String, + configOptions: JSONOptions): Iterator[InternalRow] = { def failedRecord(record: String): Seq[InternalRow] = { // create a row even if no corrupt record column is present @@ -241,37 +246,36 @@ private[sql] object JacksonParser { Seq(row) } - json.mapPartitions { iter => - val factory = new JsonFactory() - - iter.flatMap { record => - if (record.trim.isEmpty) { - Nil - } else { - try { - Utils.tryWithResource(factory.createParser(record)) { parser => - parser.nextToken() - - convertField(factory, parser, schema) match { - case null => failedRecord(record) - case row: InternalRow => row :: Nil - case array: ArrayData => - if (array.numElements() == 0) { - Nil - } else { - array.toArray[InternalRow](schema) - } - case _ => - sys.error( - s"Failed to parse record $record. Please make sure that each line of " + - "the file (or each string in the RDD) is a valid JSON object or " + - "an array of JSON objects.") - } + val factory = new JsonFactory() + configOptions.setJacksonOptions(factory) + + input.flatMap { record => + if (record.trim.isEmpty) { + Nil + } else { + try { + Utils.tryWithResource(factory.createParser(record)) { parser => + parser.nextToken() + + convertField(factory, parser, schema) match { + case null => failedRecord(record) + case row: InternalRow => row :: Nil + case array: ArrayData => + if (array.numElements() == 0) { + Nil + } else { + array.toArray[InternalRow](schema) + } + case _ => + sys.error( + s"Failed to parse record $record. Please make sure that each line of " + + "the file (or each string in the RDD) is a valid JSON object or " + + "an array of JSON objects.") } - } catch { - case _: JsonProcessingException => - failedRecord(record) } + } catch { + case _: JsonProcessingException => + failedRecord(record) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index 7f3394c20ed3d..5f9f9083098a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -108,6 +108,9 @@ private[parquet] class CatalystSchemaConverter( def typeString = if (originalType == null) s"$typeName" else s"$typeName ($originalType)" + def typeNotSupported() = + throw new AnalysisException(s"Parquet type not supported: $typeString") + def typeNotImplemented() = throw new AnalysisException(s"Parquet type not yet supported: $typeString") @@ -142,6 +145,9 @@ private[parquet] class CatalystSchemaConverter( case INT_32 | null => IntegerType case DATE => DateType case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT32) + case UINT_8 => typeNotSupported() + case UINT_16 => typeNotSupported() + case UINT_32 => typeNotSupported() case TIME_MILLIS => typeNotImplemented() case _ => illegalType() } @@ -150,6 +156,7 @@ private[parquet] class CatalystSchemaConverter( originalType match { case INT_64 | null => LongType case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT64) + case UINT_64 => typeNotSupported() case TIMESTAMP_MILLIS => typeNotImplemented() case _ => illegalType() } @@ -163,9 +170,10 @@ private[parquet] class CatalystSchemaConverter( case BINARY => originalType match { - case UTF8 | ENUM => StringType + case UTF8 | ENUM | JSON => StringType case null if assumeBinaryIsString => StringType case null => BinaryType + case BSON => BinaryType case DECIMAL => makeDecimalType() case _ => illegalType() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala index 483363d2c1a21..6862dea5e6c3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala @@ -429,7 +429,7 @@ private[parquet] object CatalystWriteSupport { def setSchema(schema: StructType, configuration: Configuration): Unit = { schema.map(_.name).foreach(CatalystSchemaConverter.checkFieldName) configuration.set(SPARK_ROW_SCHEMA, schema.json) - configuration.set( + configuration.setIfUnset( ParquetOutputFormat.WRITER_VERSION, ParquetProperties.WriterVersion.PARQUET_1_0.toString) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 5a7c6b95b565f..cb0aab8cc0d09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -109,7 +109,7 @@ private[sql] class ParquetRelation( override val userDefinedPartitionColumns: Option[StructType], parameters: Map[String, String])( val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec) + extends HadoopFsRelation(maybePartitionSpec, parameters) with Logging { private[sql] def this( @@ -383,7 +383,7 @@ private[sql] class ParquetRelation( var schema: StructType = _ // Cached leaves - var cachedLeaves: Set[FileStatus] = null + var cachedLeaves: mutable.LinkedHashSet[FileStatus] = null /** * Refreshes `FileStatus`es, footers, partition spec, and table schema. @@ -396,13 +396,13 @@ private[sql] class ParquetRelation( !cachedLeaves.equals(currentLeafStatuses) if (leafStatusesChanged) { - cachedLeaves = currentLeafStatuses.toIterator.toSet + cachedLeaves = currentLeafStatuses // Lists `FileStatus`es of all leaf nodes (files) under all base directories. val leaves = currentLeafStatuses.filter { f => isSummaryFile(f.getPath) || !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) - }.toArray + }.toArray.sortBy(_.getPath.toString) dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) metadataStatuses = @@ -465,13 +465,30 @@ private[sql] class ParquetRelation( // You should enable this configuration only if you are very sure that for the parquet // part-files to read there are corresponding summary files containing correct schema. + // As filed in SPARK-11500, the order of files to touch is a matter, which might affect + // the ordering of the output columns. There are several things to mention here. + // + // 1. If mergeRespectSummaries config is false, then it merges schemas by reducing from + // the first part-file so that the columns of the lexicographically first file show + // first. + // + // 2. If mergeRespectSummaries config is true, then there should be, at least, + // "_metadata"s for all given files, so that we can ensure the columns of + // the lexicographically first file show first. + // + // 3. If shouldMergeSchemas is false, but when multiple files are given, there is + // no guarantee of the output order, since there might not be a summary file for the + // lexicographically first file, which ends up putting ahead the columns of + // the other files. However, this should be okay since not enabling + // shouldMergeSchemas means (assumes) all the files have the same schemas. + val needMerged: Seq[FileStatus] = if (mergeRespectSummaries) { Seq() } else { dataStatuses } - (metadataStatuses ++ commonMetadataStatuses ++ needMerged).toSeq + needMerged ++ metadataStatuses ++ commonMetadataStatuses } else { // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet // don't have this. @@ -768,10 +785,10 @@ private[sql] object ParquetRelation extends Logging { footers.map { footer => ParquetRelation.readSchemaFromFooter(footer, converter) - }.reduceOption(_ merge _).iterator + }.reduceLeftOption(_ merge _).iterator }.collect() - partiallyMergedSchemas.reduceOption(_ merge _) + partiallyMergedSchemas.reduceLeftOption(_ merge _) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 4b8b8e4e74dad..fbd387bc2ef47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -71,9 +71,10 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { private[sql] class TextRelation( val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], - override val paths: Array[String] = Array.empty[String]) + override val paths: Array[String] = Array.empty[String], + parameters: Map[String, String] = Map.empty[String, String]) (@transient val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec) { + extends HadoopFsRelation(maybePartitionSpec, parameters) { /** Data schema is always a single column, named "text". */ override def dataSchema: StructType = new StructType().add("value", StringType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index c5cd6a2fd6372..004407b2e6925 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -54,7 +54,7 @@ case class BroadcastLeftSemiJoinHash( val hashSet = buildKeyHashSet(input.toIterator, SQLMetrics.nullLongMetric) val broadcastedRelation = sparkContext.broadcast(hashSet) - left.execute().mapPartitions { streamIter => + left.execute().mapPartitionsInternal { streamIter => hashSemiJoin(streamIter, numLeftRows, broadcastedRelation.value, numOutputRows) } } else { @@ -62,7 +62,7 @@ case class BroadcastLeftSemiJoinHash( HashedRelation(input.toIterator, SQLMetrics.nullLongMetric, rightKeyGenerator, input.size) val broadcastedRelation = sparkContext.broadcast(hashRelation) - left.execute().mapPartitions { streamIter => + left.execute().mapPartitionsInternal { streamIter => val hashedRelation = broadcastedRelation.value hashedRelation match { case unsafe: UnsafeHashedRelation => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 05d20f511aef8..aab177b2e8427 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.util.collection.{BitSet, CompactBuffer} case class BroadcastNestedLoopJoin( @@ -95,9 +95,7 @@ case class BroadcastNestedLoopJoin( /** All rows that either match both-way, or rows from streamed joined with nulls. */ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => val matchedRows = new CompactBuffer[InternalRow] - // TODO: Use Spark's BitSet. - val includedBroadcastTuples = - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow val leftNulls = new GenericMutableRow(left.output.size) @@ -115,11 +113,11 @@ case class BroadcastNestedLoopJoin( case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy() streamRowMatched = true - includedBroadcastTuples += i + includedBroadcastTuples.set(i) case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => matchedRows += resultProj(joinedRow(broadcastedRow, streamedRow)).copy() streamRowMatched = true - includedBroadcastTuples += i + includedBroadcastTuples.set(i) case _ => } i += 1 @@ -138,8 +136,8 @@ case class BroadcastNestedLoopJoin( val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) val allIncludedBroadcastTuples = includedBroadcastTuples.fold( - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - )(_ ++ _) + new BitSet(broadcastedRelation.value.size) + )(_ | _) val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) @@ -155,7 +153,7 @@ case class BroadcastNestedLoopJoin( val joinedRow = new JoinedRow joinedRow.withLeft(leftNulls) while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { + if (!allIncludedBroadcastTuples.get(i)) { buf += resultProj(joinedRow.withRight(rel(i))).copy() } i += 1 @@ -164,7 +162,7 @@ case class BroadcastNestedLoopJoin( val joinedRow = new JoinedRow joinedRow.withRight(rightNulls) while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { + if (!allIncludedBroadcastTuples.get(i)) { buf += resultProj(joinedRow.withLeft(rel(i))).copy() } i += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index 0243e196dbc37..f467519b802a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -46,7 +46,7 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod row.copy() } - leftResults.cartesian(rightResults).mapPartitions { iter => + leftResults.cartesian(rightResults).mapPartitionsInternal { iter => val joinedRow = new JoinedRow iter.map { r => numOutputRows += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 7ce4a517838cb..fb961d97c3c3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,29 +44,15 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output - protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && self.unsafeEnabled - && UnsafeProjection.canSupport(buildKeys) - && UnsafeProjection.canSupport(self.schema)) - } - - override def outputsUnsafeRows: Boolean = isUnsafeMode - override def canProcessUnsafeRows: Boolean = isUnsafeMode - override def canProcessSafeRows: Boolean = !isUnsafeMode + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false protected def buildSideKeyGenerator: Projection = - if (isUnsafeMode) { - UnsafeProjection.create(buildKeys, buildPlan.output) - } else { - newMutableProjection(buildKeys, buildPlan.output)() - } + UnsafeProjection.create(buildKeys, buildPlan.output) protected def streamSideKeyGenerator: Projection = - if (isUnsafeMode) { - UnsafeProjection.create(streamedKeys, streamedPlan.output) - } else { - newMutableProjection(streamedKeys, streamedPlan.output)() - } + UnsafeProjection.create(streamedKeys, streamedPlan.output) protected def hashJoin( streamIter: Iterator[InternalRow], @@ -81,13 +67,8 @@ trait HashJoin { // Mutable per row objects. private[this] val joinRow = new JoinedRow - private[this] val resultProjection: (InternalRow) => InternalRow = { - if (isUnsafeMode) { - UnsafeProjection.create(self.schema) - } else { - identity[InternalRow] - } - } + private[this] val resultProjection: (InternalRow) => InternalRow = + UnsafeProjection.create(self.schema) private[this] val joinKeys = streamSideKeyGenerator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 15b06b1537f8c..ed626fef56af7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -64,38 +64,18 @@ trait HashOuterJoin { s"HashOuterJoin should not take $x as the JoinType") } - protected[this] def isUnsafeMode: Boolean = { - (self.codegenEnabled && self.unsafeEnabled && joinType != FullOuter - && UnsafeProjection.canSupport(buildKeys) - && UnsafeProjection.canSupport(self.schema)) - } - - override def outputsUnsafeRows: Boolean = isUnsafeMode - override def canProcessUnsafeRows: Boolean = isUnsafeMode - override def canProcessSafeRows: Boolean = !isUnsafeMode + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false protected def buildKeyGenerator: Projection = - if (isUnsafeMode) { - UnsafeProjection.create(buildKeys, buildPlan.output) - } else { - newMutableProjection(buildKeys, buildPlan.output)() - } + UnsafeProjection.create(buildKeys, buildPlan.output) - protected[this] def streamedKeyGenerator: Projection = { - if (isUnsafeMode) { - UnsafeProjection.create(streamedKeys, streamedPlan.output) - } else { - newProjection(streamedKeys, streamedPlan.output) - } - } + protected[this] def streamedKeyGenerator: Projection = + UnsafeProjection.create(streamedKeys, streamedPlan.output) - protected[this] def resultProjection: InternalRow => InternalRow = { - if (isUnsafeMode) { - UnsafeProjection.create(self.schema) - } else { - identity[InternalRow] - } - } + protected[this] def resultProjection: InternalRow => InternalRow = + UnsafeProjection.create(self.schema) @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() @@ -173,8 +153,12 @@ trait HashOuterJoin { } protected[this] def fullOuterIterator( - key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow], - joinedRow: JoinedRow, numOutputRows: LongSQLMetric): Iterator[InternalRow] = { + key: InternalRow, + leftIter: Iterable[InternalRow], + rightIter: Iterable[InternalRow], + joinedRow: JoinedRow, + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { if (!key.anyNull) { // Store the positions of records in right, if one of its associated row satisfy // the join condition. @@ -191,7 +175,7 @@ trait HashOuterJoin { matched = true // if the row satisfy the join condition, add its index into the matched set rightMatchedSet.add(idx) - joinedRow.copy() + resultProjection(joinedRow) } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { // 2. For those unmatched records in left, append additional records with empty right. @@ -201,7 +185,7 @@ trait HashOuterJoin { // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. numOutputRows += 1 - joinedRow.withRight(rightNullRow).copy() + resultProjection(joinedRow.withRight(rightNullRow)) }) } ++ rightIter.zipWithIndex.collect { // 3. For those unmatched records in right, append additional records with empty left. @@ -210,15 +194,15 @@ trait HashOuterJoin { // in the matched set. case (r, idx) if !rightMatchedSet.contains(idx) => numOutputRows += 1 - joinedRow(leftNullRow, r).copy() + resultProjection(joinedRow(leftNullRow, r)) } } else { leftIter.iterator.map[InternalRow] { l => numOutputRows += 1 - joinedRow(l, rightNullRow).copy() + resultProjection(joinedRow(l, rightNullRow)) } ++ rightIter.iterator.map[InternalRow] { r => numOutputRows += 1 - joinedRow(leftNullRow, r).copy() + resultProjection(joinedRow(leftNullRow, r)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index beb141ade616d..f23a1830e91c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -33,31 +33,15 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output - protected[this] def supportUnsafe: Boolean = { - (self.codegenEnabled && self.unsafeEnabled - && UnsafeProjection.canSupport(leftKeys) - && UnsafeProjection.canSupport(rightKeys) - && UnsafeProjection.canSupport(left.schema) - && UnsafeProjection.canSupport(right.schema)) - } - - override def outputsUnsafeRows: Boolean = supportUnsafe - override def canProcessUnsafeRows: Boolean = supportUnsafe - override def canProcessSafeRows: Boolean = !supportUnsafe + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false protected def leftKeyGenerator: Projection = - if (supportUnsafe) { - UnsafeProjection.create(leftKeys, left.output) - } else { - newMutableProjection(leftKeys, left.output)() - } + UnsafeProjection.create(leftKeys, left.output) protected def rightKeyGenerator: Projection = - if (supportUnsafe) { - UnsafeProjection.create(rightKeys, right.output) - } else { - newMutableProjection(rightKeys, right.output)() - } + UnsafeProjection.create(rightKeys, right.output) @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala deleted file mode 100644 index 755986af8b95e..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Performs an inner hash join of two child relations by first shuffling the data using the join - * keys. - */ -case class ShuffledHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - buildSide: BuildSide, - left: SparkPlan, - right: SparkPlan) - extends BinaryNode with HashJoin { - - override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def outputPartitioning: Partitioning = - PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) - - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - protected override def doExecute(): RDD[InternalRow] = { - val (numBuildRows, numStreamedRows) = buildSide match { - case BuildLeft => (longMetric("numLeftRows"), longMetric("numRightRows")) - case BuildRight => (longMetric("numRightRows"), longMetric("numLeftRows")) - } - val numOutputRows = longMetric("numOutputRows") - - buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = HashedRelation(buildIter, numBuildRows, buildSideKeyGenerator) - hashJoin(streamIter, numStreamedRows, hashed, numOutputRows) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala deleted file mode 100644 index 6b2cb9d8f6893..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import scala.collection.JavaConverters._ - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Performs a hash based outer join for two child relations by shuffling the data using - * the join keys. This operator requires loading the associated partition in both side into memory. - */ -case class ShuffledHashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashOuterJoin { - - override private[sql] lazy val metrics = Map( - "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), - "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - - override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } - - protected override def doExecute(): RDD[InternalRow] = { - val numLeftRows = longMetric("numLeftRows") - val numRightRows = longMetric("numRightRows") - val numOutputRows = longMetric("numOutputRows") - - val joinedRow = new JoinedRow() - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - // TODO this probably can be replaced by external sort (sort merged join?) - joinType match { - case LeftOuter => - val hashed = HashedRelation(rightIter, numRightRows, buildKeyGenerator) - val keyGenerator = streamedKeyGenerator - val resultProj = resultProjection - leftIter.flatMap( currentRow => { - numLeftRows += 1 - val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows) - }) - - case RightOuter => - val hashed = HashedRelation(leftIter, numLeftRows, buildKeyGenerator) - val keyGenerator = streamedKeyGenerator - val resultProj = resultProjection - rightIter.flatMap ( currentRow => { - numRightRows += 1 - val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows) - }) - - case FullOuter => - // TODO(davies): use UnsafeRow - val leftHashTable = - buildHashTable(leftIter, numLeftRows, newProjection(leftKeys, left.output)).asScala - val rightHashTable = - buildHashTable(rightIter, numRightRows, newProjection(rightKeys, right.output)).asScala - (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST), - joinedRow, - numOutputRows) - } - - case x => - throw new IllegalArgumentException( - s"ShuffledHashOuterJoin should not take $x as the JoinType") - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 17030947b7bbc..4bf7b521c77d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -53,16 +53,9 @@ case class SortMergeJoin( override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - protected[this] def isUnsafeMode: Boolean = { - (codegenEnabled && unsafeEnabled - && UnsafeProjection.canSupport(leftKeys) - && UnsafeProjection.canSupport(rightKeys) - && UnsafeProjection.canSupport(schema)) - } - - override def outputsUnsafeRows: Boolean = isUnsafeMode - override def canProcessUnsafeRows: Boolean = isUnsafeMode - override def canProcessSafeRows: Boolean = !isUnsafeMode + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. @@ -77,26 +70,10 @@ case class SortMergeJoin( left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => new RowIterator { // The projection used to extract keys from input rows of the left child. - private[this] val leftKeyGenerator = { - if (isUnsafeMode) { - // It is very important to use UnsafeProjection if input rows are UnsafeRows. - // Otherwise, GenerateProjection will cause wrong results. - UnsafeProjection.create(leftKeys, left.output) - } else { - newProjection(leftKeys, left.output) - } - } + private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output) // The projection used to extract keys from input rows of the right child. - private[this] val rightKeyGenerator = { - if (isUnsafeMode) { - // It is very important to use UnsafeProjection if input rows are UnsafeRows. - // Otherwise, GenerateProjection will cause wrong results. - UnsafeProjection.create(rightKeys, right.output) - } else { - newProjection(rightKeys, right.output) - } - } + private[this] val rightKeyGenerator = UnsafeProjection.create(rightKeys, right.output) // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) @@ -113,13 +90,8 @@ case class SortMergeJoin( numRightRows ) private[this] val joinRow = new JoinedRow - private[this] val resultProjection: (InternalRow) => InternalRow = { - if (isUnsafeMode) { - UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } - } + private[this] val resultProjection: (InternalRow) => InternalRow = + UnsafeProjection.create(schema) override def advanceNext(): Boolean = { if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index 7e854e6702f77..efaa69c1d3227 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -89,32 +89,15 @@ case class SortMergeOuterJoin( keys.map(SortOrder(_, Ascending)) } - private def isUnsafeMode: Boolean = { - (codegenEnabled && unsafeEnabled - && UnsafeProjection.canSupport(leftKeys) - && UnsafeProjection.canSupport(rightKeys) - && UnsafeProjection.canSupport(schema)) - } - - override def outputsUnsafeRows: Boolean = isUnsafeMode - override def canProcessUnsafeRows: Boolean = isUnsafeMode - override def canProcessSafeRows: Boolean = !isUnsafeMode + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false - private def createLeftKeyGenerator(): Projection = { - if (isUnsafeMode) { - UnsafeProjection.create(leftKeys, left.output) - } else { - newProjection(leftKeys, left.output) - } - } + private def createLeftKeyGenerator(): Projection = + UnsafeProjection.create(leftKeys, left.output) - private def createRightKeyGenerator(): Projection = { - if (isUnsafeMode) { - UnsafeProjection.create(rightKeys, right.output) - } else { - newProjection(rightKeys, right.output) - } - } + private def createRightKeyGenerator(): Projection = + UnsafeProjection.create(rightKeys, right.output) override def doExecute(): RDD[InternalRow] = { val numLeftRows = longMetric("numLeftRows") @@ -131,13 +114,7 @@ case class SortMergeOuterJoin( (r: InternalRow) => true } } - val resultProj: InternalRow => InternalRow = { - if (isUnsafeMode) { - UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } - } + val resultProj: InternalRow => InternalRow = UnsafeProjection.create(schema) joinType match { case LeftOuter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala index 52dcb9e43c4e8..3dcef94095647 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala @@ -50,11 +50,7 @@ case class BinaryHashJoinNode( private def buildSideKeyGenerator: Projection = { // We are expecting the data types of buildKeys and streamedKeys are the same. assert(buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)) - if (isUnsafeMode) { - UnsafeProjection.create(buildKeys, buildNode.output) - } else { - newMutableProjection(buildKeys, buildNode.output)() - } + UnsafeProjection.create(buildKeys, buildNode.output) } protected override def doOpen(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala index b1dc719ca8508..fd7948ffa9a9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala @@ -45,20 +45,8 @@ trait HashJoinNode { private[this] var hashed: HashedRelation = _ private[this] var joinKeys: Projection = _ - protected def isUnsafeMode: Boolean = { - (codegenEnabled && - unsafeEnabled && - UnsafeProjection.canSupport(schema) && - UnsafeProjection.canSupport(streamedKeys)) - } - - private def streamSideKeyGenerator: Projection = { - if (isUnsafeMode) { - UnsafeProjection.create(streamedKeys, streamedNode.output) - } else { - newMutableProjection(streamedKeys, streamedNode.output)() - } - } + private def streamSideKeyGenerator: Projection = + UnsafeProjection.create(streamedKeys, streamedNode.output) /** * Sets the HashedRelation used by this node. This method needs to be called after @@ -76,13 +64,7 @@ trait HashJoinNode { override def open(): Unit = { doOpen() joinRow = new JoinedRow - resultProjection = { - if (isUnsafeMode) { - UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } - } + resultProjection = UnsafeProjection.create(schema) joinKeys = streamSideKeyGenerator } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index f96b62a67a254..d3381eac91d43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -35,10 +35,6 @@ import org.apache.spark.sql.types.StructType */ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging { - protected val codegenEnabled: Boolean = conf.codegenEnabled - - protected val unsafeEnabled: Boolean = conf.unsafeEnabled - private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing") /** @@ -111,21 +107,17 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { log.debug( - s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled) { - try { - GenerateProjection.generate(expressions, inputSchema) - } catch { - case NonFatal(e) => - if (isTesting) { - throw e - } else { - log.error("Failed to generate projection, fallback to interpret", e) - new InterpretedProjection(expressions, inputSchema) - } - } - } else { - new InterpretedProjection(expressions, inputSchema) + s"Creating Projection: $expressions, inputSchema: $inputSchema") + try { + GenerateProjection.generate(expressions, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate projection, fallback to interpret", e) + new InterpretedProjection(expressions, inputSchema) + } } } @@ -133,41 +125,33 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { log.debug( - s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled) { - try { - GenerateMutableProjection.generate(expressions, inputSchema) - } catch { - case NonFatal(e) => - if (isTesting) { - throw e - } else { - log.error("Failed to generate mutable projection, fallback to interpreted", e) - () => new InterpretedMutableProjection(expressions, inputSchema) - } - } - } else { - () => new InterpretedMutableProjection(expressions, inputSchema) + s"Creating MutableProj: $expressions, inputSchema: $inputSchema") + try { + GenerateMutableProjection.generate(expressions, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate mutable projection, fallback to interpreted", e) + () => new InterpretedMutableProjection(expressions, inputSchema) + } } } protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { - if (codegenEnabled) { - try { - GeneratePredicate.generate(expression, inputSchema) - } catch { - case NonFatal(e) => - if (isTesting) { - throw e - } else { - log.error("Failed to generate predicate, fallback to interpreted", e) - InterpretedPredicate.create(expression, inputSchema) - } - } - } else { - InterpretedPredicate.create(expression, inputSchema) + try { + GeneratePredicate.generate(expression, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate predicate, fallback to interpreted", e) + InterpretedPredicate.create(expression, inputSchema) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala index 0e601cd2cab5d..5f8fc2de8b46d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -28,8 +28,6 @@ import org.apache.spark.sql.catalyst.rules.Rule */ case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { - require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe") - override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -97,18 +95,10 @@ private[sql] object EnsureRowFormats extends Rule[SparkPlan] { case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) => if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) { // If this operator's children produce both unsafe and safe rows, - // convert everything unsafe rows if all the schema of them are support by UnsafeRow - if (operator.children.forall(c => UnsafeProjection.canSupport(c.schema))) { - operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c - } - } - } else { - operator.withNewChildren { - operator.children.map { - c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c - } + // convert everything unsafe rows. + operator.withNewChildren { + operator.children.map { + c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c } } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala new file mode 100644 index 0000000000000..360c9a5bc15e7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn} + +/** + * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]] + * operations to take all of the elements of a group and reduce them to a single value. + * + * For example, the following aggregator extracts an `int` from a specific class and adds them up: + * {{{ + * case class Data(i: Int) + * + * val customSummer = new Aggregator[Data, Int, Int] { + * def zero = 0 + * def reduce(b: Int, a: Data) = b + a.i + * def present(r: Int) = r + * }.toColumn() + * + * val ds: Dataset[Data] + * val aggregated = ds.select(customSummer) + * }}} + * + * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird + * + * @tparam A The input type for the aggregation. + * @tparam B The type of the intermediate value of the reduction. + * @tparam C The type of the final result. + */ +abstract class Aggregator[-A, B, C] { + + /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + def zero: B + + /** + * Combine two values to produce a new value. For performance, the function may modify `b` and + * return it instead of constructing new object for b. + */ + def reduce(b: B, a: A): B + + /** + * Merge two intermediate values + */ + def merge(b1: B, b2: B): B + + /** + * Transform the output of the reduction. + */ + def finish(reduction: B): C + + /** + * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]] + * operations. + */ + def toColumn( + implicit bEncoder: Encoder[B], + cEncoder: Encoder[C]): TypedColumn[A, C] = { + val expr = + new AggregateExpression( + TypedAggregateExpression(this), + Complete, + false) + + new TypedColumn[A, C](expr, encoderFor[C]) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 8b9247adea200..fc873c04f88f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.types.BooleanType import org.apache.spark.sql.{Column, catalyst} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ /** @@ -141,40 +141,56 @@ class WindowSpec private[sql]( */ private[sql] def withAggregate(aggregate: Column): Column = { val windowExpr = aggregate.expr match { - case Average(child) => WindowExpression( - UnresolvedWindowFunction("avg", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Sum(child) => WindowExpression( - UnresolvedWindowFunction("sum", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Count(child) => WindowExpression( - UnresolvedWindowFunction("count", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case First(child, ignoreNulls) => WindowExpression( - // TODO this is a hack for Hive UDAF first_value - UnresolvedWindowFunction( - "first_value", - child :: ignoreNulls :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Last(child, ignoreNulls) => WindowExpression( - // TODO this is a hack for Hive UDAF last_value - UnresolvedWindowFunction( - "last_value", - child :: ignoreNulls :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Min(child) => WindowExpression( - UnresolvedWindowFunction("min", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Max(child) => WindowExpression( - UnresolvedWindowFunction("max", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case wf: WindowFunction => WindowExpression( - wf, - WindowSpecDefinition(partitionSpec, orderSpec, frame)) + // First, we check if we get an aggregate function without the DISTINCT keyword. + // Right now, we do not support using a DISTINCT aggregate function as a + // window function. + case AggregateExpression(aggregateFunction, _, isDistinct) if !isDistinct => + aggregateFunction match { + case Average(child) => WindowExpression( + UnresolvedWindowFunction("avg", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Sum(child) => WindowExpression( + UnresolvedWindowFunction("sum", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Count(child) => WindowExpression( + UnresolvedWindowFunction("count", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case First(child, ignoreNulls) => WindowExpression( + // TODO this is a hack for Hive UDAF first_value + UnresolvedWindowFunction( + "first_value", + child :: ignoreNulls :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Last(child, ignoreNulls) => WindowExpression( + // TODO this is a hack for Hive UDAF last_value + UnresolvedWindowFunction( + "last_value", + child :: ignoreNulls :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Min(child) => WindowExpression( + UnresolvedWindowFunction("min", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Max(child) => WindowExpression( + UnresolvedWindowFunction("max", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case x => + throw new UnsupportedOperationException(s"$x is not supported in a window operation.") + } + + case AggregateExpression(aggregateFunction, _, isDistinct) if isDistinct => + throw new UnsupportedOperationException( + s"Distinct aggregate function ${aggregateFunction} is not supported " + + s"in window operation.") + + case wf: WindowFunction => + WindowExpression( + wf, + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case x => - throw new UnsupportedOperationException(s"$x is not supported in window operation.") + throw new UnsupportedOperationException(s"$x is not supported in a window operation.") } + new Column(windowExpr) } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 258afadc76951..11dbf391cff98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2} +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.types._ @@ -109,7 +109,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { @scala.annotation.varargs def apply(exprs: Column*): Column = { val aggregateExpression = - AggregateExpression2( + AggregateExpression( ScalaUDAF(exprs.map(_.expr), this), Complete, isDistinct = false) @@ -123,7 +123,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { @scala.annotation.varargs def distinct(exprs: Column*): Column = { val aggregateExpression = - AggregateExpression2( + AggregateExpression( ScalaUDAF(exprs.map(_.expr), this), Complete, isDistinct = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 04627589886a8..95158de710acf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql + + import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} import scala.util.Try @@ -24,11 +26,33 @@ import scala.util.Try import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} +import org.apache.spark.sql.catalyst.encoders.FlatEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +/** + * Ensures that java functions signatures for methods that now return a [[TypedColumn]] still have + * legacy equivalents in bytecode. This compatibility is done by forcing the compiler to generate + * "bridge" methods due to the use of covariant return types. + * + * {{{ + * In LegacyFunctions: + * public abstract org.apache.spark.sql.Column avg(java.lang.String); + * + * In functions: + * public static org.apache.spark.sql.TypedColumn avg(...); + * }}} + * + * This allows us to use the same functions both in typed [[Dataset]] operations and untyped + * [[DataFrame]] operations when the return type for a given function is statically known. + */ +private[sql] abstract class LegacyFunctions { + def count(columnName: String): Column +} + /** * :: Experimental :: * Functions available for [[DataFrame]]. @@ -48,11 +72,17 @@ import org.apache.spark.util.Utils */ @Experimental // scalastyle:off -object functions { +object functions extends LegacyFunctions { // scalastyle:on private def withExpr(expr: Expression): Column = Column(expr) + private def withAggregateFunction( + func: AggregateFunction, + isDistinct: Boolean = false): Column = { + Column(func.toAggregateExpression(isDistinct)) + } + /** * Returns a [[Column]] based on the given column name. * @@ -128,7 +158,9 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(e: Column): Column = withExpr { ApproxCountDistinct(e.expr) } + def approxCountDistinct(e: Column): Column = withAggregateFunction { + HyperLogLogPlusPlus(e.expr) + } /** * Aggregate function: returns the approximate number of distinct items in a group. @@ -144,8 +176,8 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(e: Column, rsd: Double): Column = withExpr { - ApproxCountDistinct(e.expr, rsd) + def approxCountDistinct(e: Column, rsd: Double): Column = withAggregateFunction { + HyperLogLogPlusPlus(e.expr, rsd, 0, 0) } /** @@ -164,7 +196,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def avg(e: Column): Column = withExpr { Average(e.expr) } + def avg(e: Column): Column = withAggregateFunction { Average(e.expr) } /** * Aggregate function: returns the average of the values in a group. @@ -174,13 +206,33 @@ object functions { */ def avg(columnName: String): Column = avg(Column(columnName)) + /** + * Aggregate function: returns a list of objects with duplicates. + * + * For now this is an alias for the collect_list Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_list(e: Column): Column = callUDF("collect_list", e) + + /** + * Aggregate function: returns a set of objects with duplicate elements eliminated. + * + * For now this is an alias for the collect_set Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_set(e: Column): Column = callUDF("collect_set", e) + /** * Aggregate function: returns the Pearson Correlation Coefficient for two columns. * * @group agg_funcs * @since 1.6.0 */ - def corr(column1: Column, column2: Column): Column = withExpr { + def corr(column1: Column, column2: Column): Column = withAggregateFunction { Corr(column1.expr, column2.expr) } @@ -200,7 +252,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def count(e: Column): Column = withExpr { + def count(e: Column): Column = withAggregateFunction { e.expr match { // Turn count(*) into count(1) case s: Star => Count(Literal(1)) @@ -214,7 +266,8 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def count(columnName: String): Column = count(Column(columnName)) + def count(columnName: String): TypedColumn[Any, Long] = + count(Column(columnName)).as(FlatEncoder[Long]) /** * Aggregate function: returns the number of distinct items in a group. @@ -223,8 +276,8 @@ object functions { * @since 1.3.0 */ @scala.annotation.varargs - def countDistinct(expr: Column, exprs: Column*): Column = withExpr { - CountDistinct((expr +: exprs).map(_.expr)) + def countDistinct(expr: Column, exprs: Column*): Column = { + withAggregateFunction(Count.apply((expr +: exprs).map(_.expr)), isDistinct = true) } /** @@ -243,7 +296,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def first(e: Column): Column = withExpr { First(e.expr) } + def first(e: Column): Column = withAggregateFunction { new First(e.expr) } /** * Aggregate function: returns the first value of a column in a group. @@ -259,7 +312,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def kurtosis(e: Column): Column = withExpr { Kurtosis(e.expr) } + def kurtosis(e: Column): Column = withAggregateFunction { Kurtosis(e.expr) } /** * Aggregate function: returns the last value in a group. @@ -267,7 +320,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def last(e: Column): Column = withExpr { Last(e.expr) } + def last(e: Column): Column = withAggregateFunction { new Last(e.expr) } /** * Aggregate function: returns the last value of the column in a group. @@ -283,7 +336,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def max(e: Column): Column = withExpr { Max(e.expr) } + def max(e: Column): Column = withAggregateFunction { Max(e.expr) } /** * Aggregate function: returns the maximum value of the column in a group. @@ -317,7 +370,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def min(e: Column): Column = withExpr { Min(e.expr) } + def min(e: Column): Column = withAggregateFunction { Min(e.expr) } /** * Aggregate function: returns the minimum value of the column in a group. @@ -333,7 +386,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def skewness(e: Column): Column = withExpr { Skewness(e.expr) } + def skewness(e: Column): Column = withAggregateFunction { Skewness(e.expr) } /** * Aggregate function: alias for [[stddev_samp]]. @@ -341,16 +394,16 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev(e: Column): Column = withExpr { StddevSamp(e.expr) } + def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } /** - * Aggregate function: returns the unbiased sample standard deviation of + * Aggregate function: returns the sample standard deviation of * the expression in a group. * * @group agg_funcs * @since 1.6.0 */ - def stddev_samp(e: Column): Column = withExpr { StddevSamp(e.expr) } + def stddev_samp(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } /** * Aggregate function: returns the population standard deviation of @@ -359,7 +412,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev_pop(e: Column): Column = withExpr { StddevPop(e.expr) } + def stddev_pop(e: Column): Column = withAggregateFunction { StddevPop(e.expr) } /** * Aggregate function: returns the sum of all values in the expression. @@ -367,7 +420,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def sum(e: Column): Column = withExpr { Sum(e.expr) } + def sum(e: Column): Column = withAggregateFunction { Sum(e.expr) } /** * Aggregate function: returns the sum of all values in the given column. @@ -383,7 +436,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def sumDistinct(e: Column): Column = withExpr { SumDistinct(e.expr) } + def sumDistinct(e: Column): Column = withAggregateFunction(Sum(e.expr), isDistinct = true) /** * Aggregate function: returns the sum of distinct values in the expression. @@ -399,7 +452,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def variance(e: Column): Column = withExpr { VarianceSamp(e.expr) } + def variance(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } /** * Aggregate function: returns the unbiased variance of the values in a group. @@ -407,7 +460,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def var_samp(e: Column): Column = withExpr { VarianceSamp(e.expr) } + def var_samp(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } /** * Aggregate function: returns the population variance of the values in a group. @@ -415,7 +468,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def var_pop(e: Column): Column = withExpr { VariancePop(e.expr) } + def var_pop(e: Column): Column = withAggregateFunction { VariancePop(e.expr) } ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions @@ -2252,6 +2305,18 @@ object functions { */ def explode(e: Column): Column = withExpr { Explode(e.expr) } + /** + * Creates a new row for a json column according to the given field names. + * + * @group collection_funcs + * @since 1.6.0 + */ + @scala.annotation.varargs + def json_tuple(json: Column, fields: String*): Column = withExpr { + require(fields.length > 0, "at least 1 field name should be given.") + JsonTuple(json.expr +: fields.map(Literal.apply)) + } + /** * Returns length of array or map. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index e296d631f0f30..b3d3bdf50df63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -235,9 +235,11 @@ abstract class BaseRelation { def needConversion: Boolean = true /** - * Given an array of [[Filter]]s, returns an array of [[Filter]]s that this data source relation - * cannot handle. Spark SQL will apply all returned [[Filter]]s against rows returned by this - * data source relation. + * Returns the list of [[Filter]]s that this datasource may not be able to handle. + * These returned [[Filter]]s will be evaluated by Spark SQL after data is output by a scan. + * By default, this function will return all filters, as it is always safe to + * double evaluate a [[Filter]]. However, specific implementations can override this function to + * avoid double filtering when they are capable of processing a filter internally. * * @since 1.6.0 */ @@ -414,25 +416,30 @@ abstract class OutputWriter { * @since 1.4.0 */ @Experimental -abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[PartitionSpec]) +abstract class HadoopFsRelation private[sql]( + maybePartitionSpec: Option[PartitionSpec], + parameters: Map[String, String]) extends BaseRelation with FileRelation with Logging { override def toString: String = getClass.getSimpleName + paths.mkString("[", ",", "]") - def this() = this(None) + def this() = this(None, Map.empty[String, String]) - private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) + def this(parameters: Map[String, String]) = this(None, parameters) + + private[sql] def this(maybePartitionSpec: Option[PartitionSpec]) = + this(maybePartitionSpec, Map.empty[String, String]) - private val codegenEnabled = sqlContext.conf.codegenEnabled + private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) private var _partitionSpec: PartitionSpec = _ private class FileStatusCache { - var leafFiles = mutable.Map.empty[Path, FileStatus] + var leafFiles = mutable.LinkedHashMap.empty[Path, FileStatus] var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] - private def listLeafFiles(paths: Array[String]): Set[FileStatus] = { + private def listLeafFiles(paths: Array[String]): mutable.LinkedHashSet[FileStatus] = { if (paths.length >= sqlContext.conf.parallelPartitionDiscoveryThreshold) { HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sqlContext.sparkContext) } else { @@ -450,10 +457,11 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio val (dirs, files) = statuses.partition(_.isDir) + // It uses [[LinkedHashSet]] since the order of files can affect the results. (SPARK-11500) if (dirs.isEmpty) { - files.toSet + mutable.LinkedHashSet(files: _*) } else { - files.toSet ++ listLeafFiles(dirs.map(_.getPath.toString)) + mutable.LinkedHashSet(files: _*) ++ listLeafFiles(dirs.map(_.getPath.toString)) } } } @@ -464,7 +472,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio leafFiles.clear() leafDirToChildrenFiles.clear() - leafFiles ++= files.map(f => f.getPath -> f).toMap + leafFiles ++= files.map(f => f.getPath -> f) leafDirToChildrenFiles ++= files.toArray.groupBy(_.getPath.getParent) } } @@ -475,8 +483,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio cache } - protected def cachedLeafStatuses(): Set[FileStatus] = { - fileStatusCache.leafFiles.values.toSet + protected def cachedLeafStatuses(): mutable.LinkedHashSet[FileStatus] = { + mutable.LinkedHashSet(fileStatusCache.leafFiles.values.toArray: _*) } final private[sql] def partitionSpec: PartitionSpec = { @@ -518,13 +526,37 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio } /** - * Base paths of this relation. For partitioned relations, it should be either root directories + * Paths of this relation. For partitioned relations, it should be root directories * of all partition directories. * * @since 1.4.0 */ def paths: Array[String] + /** + * Contains a set of paths that are considered as the base dirs of the input datasets. + * The partitioning discovery logic will make sure it will stop when it reaches any + * base path. By default, the paths of the dataset provided by users will be base paths. + * For example, if a user uses `sqlContext.read.parquet("/path/something=true/")`, the base path + * will be `/path/something=true/`, and the returned DataFrame will not contain a column of + * `something`. If users want to override the basePath. They can set `basePath` in the options + * to pass the new base path to the data source. + * For the above example, if the user-provided base path is `/path/`, the returned + * DataFrame will have the column of `something`. + */ + private def basePaths: Set[Path] = { + val userDefinedBasePath = parameters.get("basePath").map(basePath => Set(new Path(basePath))) + userDefinedBasePath.getOrElse { + // If the user does not provide basePath, we will just use paths. + val pathSet = paths.toSet + pathSet.map(p => new Path(p)) + }.map { hdfsPath => + // Make the path qualified (consistent with listLeafFiles and listLeafFilesInParallel). + val fs = hdfsPath.getFileSystem(hadoopConf) + hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + } + override def inputFiles: Array[String] = cachedLeafStatuses().map(_.getPath.toString).toArray override def sizeInBytes: Long = cachedLeafStatuses().map(_.getLen).sum @@ -558,7 +590,10 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio userDefinedPartitionColumns match { case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => val spec = PartitioningUtils.parsePartitions( - leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, typeInference = false) + leafDirs, + PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference = false, + basePaths = basePaths) // Without auto inference, all of value in the `row` should be null or in StringType, // we need to cast into the data type that user specified. @@ -576,8 +611,11 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio case _ => // user did not provide a partitioning schema - PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, - typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled()) + PartitioningUtils.parsePartitions( + leafDirs, + PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled(), + basePaths = basePaths) } } @@ -660,7 +698,6 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus]): RDD[Row] = { // Yeah, to workaround serialization... val dataSchema = this.dataSchema - val codegenEnabled = this.codegenEnabled val needConversion = this.needConversion val requiredOutput = requiredColumns.map { col => @@ -677,11 +714,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio } converted.mapPartitions { rows => - val buildProjection = if (codegenEnabled) { + val buildProjection = GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes) - } else { - () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes) - } val projectedRows = { val mutableProjection = buildProjection() @@ -834,7 +868,7 @@ private[sql] object HadoopFsRelation extends Logging { def listLeafFilesInParallel( paths: Array[String], hadoopConf: Configuration, - sparkContext: SparkContext): Set[FileStatus] = { + sparkContext: SparkContext): mutable.LinkedHashSet[FileStatus] = { logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") val serializableConfiguration = new SerializableConfiguration(hadoopConf) @@ -854,9 +888,10 @@ private[sql] object HadoopFsRelation extends Logging { status.getAccessTime) }.collect() - fakeStatuses.map { f => + val hadoopFakeStatuses = fakeStatuses.map { f => new FileStatus( f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, new Path(f.path)) - }.toSet + } + mutable.LinkedHashSet(hadoopFakeStatuses: _*) } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 40bff57a17a03..d191b50fa802e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -65,6 +65,13 @@ public void testExecution() { Assert.assertEquals(1, df.select("key").collect()[0].get(0)); } + @Test + public void testCollectAndTake() { + DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3"); + Assert.assertEquals(3, df.select("key").collectAsList().size()); + Assert.assertEquals(2, df.select("key").takeAsList(2).size()); + } + /** * See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java. */ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index a9493d576d179..eb6fa1e72e27b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -29,10 +29,9 @@ import org.apache.spark.Accumulator; import org.apache.spark.SparkContext; import org.apache.spark.api.java.function.*; -import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.catalyst.encoders.Encoder; -import org.apache.spark.sql.catalyst.encoders.Encoder$; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.GroupedDataset; import org.apache.spark.sql.test.TestSQLContext; @@ -42,7 +41,6 @@ public class JavaDatasetSuite implements Serializable { private transient JavaSparkContext jsc; private transient TestSQLContext context; - private transient Encoder$ e = Encoder$.MODULE$; @Before public void setUp() { @@ -67,35 +65,43 @@ private Tuple2 tuple2(T1 t1, T2 t2) { @Test public void testCollect() { List data = Arrays.asList("hello", "world"); - Dataset ds = context.createDataset(data, e.STRING()); - String[] collected = (String[]) ds.collect(); - Assert.assertEquals(Arrays.asList("hello", "world"), Arrays.asList(collected)); + Dataset ds = context.createDataset(data, Encoders.STRING()); + List collected = ds.collectAsList(); + Assert.assertEquals(Arrays.asList("hello", "world"), collected); + } + + @Test + public void testTake() { + List data = Arrays.asList("hello", "world"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + List collected = ds.takeAsList(1); + Assert.assertEquals(Arrays.asList("hello"), collected); } @Test public void testCommonOperation() { List data = Arrays.asList("hello", "world"); - Dataset ds = context.createDataset(data, e.STRING()); + Dataset ds = context.createDataset(data, Encoders.STRING()); Assert.assertEquals("hello", ds.first()); - Dataset filtered = ds.filter(new Function() { + Dataset filtered = ds.filter(new FilterFunction() { @Override - public Boolean call(String v) throws Exception { + public boolean call(String v) throws Exception { return v.startsWith("h"); } }); Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList()); - Dataset mapped = ds.map(new Function() { + Dataset mapped = ds.map(new MapFunction() { @Override public Integer call(String v) throws Exception { return v.length(); } - }, e.INT()); + }, Encoders.INT()); Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList()); - Dataset parMapped = ds.mapPartitions(new FlatMapFunction, String>() { + Dataset parMapped = ds.mapPartitions(new MapPartitionsFunction() { @Override public Iterable call(Iterator it) throws Exception { List ls = new LinkedList(); @@ -104,7 +110,7 @@ public Iterable call(Iterator it) throws Exception { } return ls; } - }, e.STRING()); + }, Encoders.STRING()); Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList()); Dataset flatMapped = ds.flatMap(new FlatMapFunction() { @@ -116,7 +122,7 @@ public Iterable call(String s) throws Exception { } return ls; } - }, e.STRING()); + }, Encoders.STRING()); Assert.assertEquals( Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"), flatMapped.collectAsList()); @@ -126,9 +132,9 @@ public Iterable call(String s) throws Exception { public void testForeach() { final Accumulator accum = jsc.accumulator(0); List data = Arrays.asList("a", "b", "c"); - Dataset ds = context.createDataset(data, e.STRING()); + Dataset ds = context.createDataset(data, Encoders.STRING()); - ds.foreach(new VoidFunction() { + ds.foreach(new ForeachFunction() { @Override public void call(String s) throws Exception { accum.add(1); @@ -140,68 +146,84 @@ public void call(String s) throws Exception { @Test public void testReduce() { List data = Arrays.asList(1, 2, 3); - Dataset ds = context.createDataset(data, e.INT()); + Dataset ds = context.createDataset(data, Encoders.INT()); - int reduced = ds.reduce(new Function2() { + int reduced = ds.reduce(new ReduceFunction() { @Override public Integer call(Integer v1, Integer v2) throws Exception { return v1 + v2; } }); Assert.assertEquals(6, reduced); - - int folded = ds.fold(1, new Function2() { - @Override - public Integer call(Integer v1, Integer v2) throws Exception { - return v1 * v2; - } - }); - Assert.assertEquals(6, folded); } @Test public void testGroupBy() { List data = Arrays.asList("a", "foo", "bar"); - Dataset ds = context.createDataset(data, e.STRING()); - GroupedDataset grouped = ds.groupBy(new Function() { + Dataset ds = context.createDataset(data, Encoders.STRING()); + GroupedDataset grouped = ds.groupBy(new MapFunction() { @Override public Integer call(String v) throws Exception { return v.length(); } - }, e.INT()); + }, Encoders.INT()); + + Dataset mapped = grouped.map(new MapGroupFunction() { + @Override + public String call(Integer key, Iterator values) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return sb.toString(); + } + }, Encoders.STRING()); - Dataset mapped = grouped.mapGroups( - new Function2, Iterator>() { + Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + + Dataset flatMapped = grouped.flatMap( + new FlatMapGroupFunction() { @Override - public Iterator call(Integer key, Iterator data) throws Exception { + public Iterable call(Integer key, Iterator values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); - while (data.hasNext()) { - sb.append(data.next()); + while (values.hasNext()) { + sb.append(values.next()); } - return Collections.singletonList(sb.toString()).iterator(); + return Collections.singletonList(sb.toString()); } }, - e.STRING()); + Encoders.STRING()); - Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList()); + + Dataset> reduced = grouped.reduce(new ReduceFunction() { + @Override + public String call(String v1, String v2) throws Exception { + return v1 + v2; + } + }); + + Assert.assertEquals( + Arrays.asList(tuple2(1, "a"), tuple2(3, "foobar")), + reduced.collectAsList()); List data2 = Arrays.asList(2, 6, 10); - Dataset ds2 = context.createDataset(data2, e.INT()); - GroupedDataset grouped2 = ds2.groupBy(new Function() { + Dataset ds2 = context.createDataset(data2, Encoders.INT()); + GroupedDataset grouped2 = ds2.groupBy(new MapFunction() { @Override public Integer call(Integer v) throws Exception { return v / 2; } - }, e.INT()); + }, Encoders.INT()); Dataset cogrouped = grouped.cogroup( grouped2, - new Function3, Iterator, Iterator>() { + new CoGroupFunction() { @Override - public Iterator call( - Integer key, - Iterator left, - Iterator right) throws Exception { + public Iterable call( + Integer key, + Iterator left, + Iterator right) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); while (left.hasNext()) { sb.append(left.next()); @@ -210,10 +232,10 @@ public Iterator call( while (right.hasNext()) { sb.append(right.next()); } - return Collections.singletonList(sb.toString()).iterator(); + return Collections.singletonList(sb.toString()); } }, - e.STRING()); + Encoders.STRING()); Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList()); } @@ -221,21 +243,22 @@ public Iterator call( @Test public void testGroupByColumn() { List data = Arrays.asList("a", "foo", "bar"); - Dataset ds = context.createDataset(data, e.STRING()); - GroupedDataset grouped = ds.groupBy(length(col("value"))).asKey(e.INT()); + Dataset ds = context.createDataset(data, Encoders.STRING()); + GroupedDataset grouped = + ds.groupBy(length(col("value"))).asKey(Encoders.INT()); - Dataset mapped = grouped.mapGroups( - new Function2, Iterator>() { + Dataset mapped = grouped.map( + new MapGroupFunction() { @Override - public Iterator call(Integer key, Iterator data) throws Exception { + public String call(Integer key, Iterator data) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); while (data.hasNext()) { sb.append(data.next()); } - return Collections.singletonList(sb.toString()).iterator(); + return sb.toString(); } }, - e.STRING()); + Encoders.STRING()); Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); } @@ -243,11 +266,11 @@ public Iterator call(Integer key, Iterator data) throws Exceptio @Test public void testSelect() { List data = Arrays.asList(2, 6); - Dataset ds = context.createDataset(data, e.INT()); + Dataset ds = context.createDataset(data, Encoders.INT()); Dataset> selected = ds.select( - expr("value + 1").as(e.INT()), - col("value").cast("string").as(e.STRING())); + expr("value + 1"), + col("value").cast("string")).as(Encoders.tuple(Encoders.INT(), Encoders.STRING())); Assert.assertEquals( Arrays.asList(tuple2(3, "2"), tuple2(7, "6")), @@ -257,14 +280,14 @@ public void testSelect() { @Test public void testSetOperation() { List data = Arrays.asList("abc", "abc", "xyz"); - Dataset ds = context.createDataset(data, e.STRING()); + Dataset ds = context.createDataset(data, Encoders.STRING()); Assert.assertEquals( Arrays.asList("abc", "xyz"), sort(ds.distinct().collectAsList().toArray(new String[0]))); List data2 = Arrays.asList("xyz", "foo", "foo"); - Dataset ds2 = context.createDataset(data2, e.STRING()); + Dataset ds2 = context.createDataset(data2, Encoders.STRING()); Dataset intersected = ds.intersect(ds2); Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList()); @@ -286,9 +309,9 @@ private > List sort(T[] data) { @Test public void testJoin() { List data = Arrays.asList(1, 2, 3); - Dataset ds = context.createDataset(data, e.INT()).as("a"); + Dataset ds = context.createDataset(data, Encoders.INT()).as("a"); List data2 = Arrays.asList(2, 3, 4); - Dataset ds2 = context.createDataset(data2, e.INT()).as("b"); + Dataset ds2 = context.createDataset(data2, Encoders.INT()).as("b"); Dataset> joined = ds.joinWith(ds2, col("a.value").equalTo(col("b.value"))); @@ -299,26 +322,28 @@ public void testJoin() { @Test public void testTupleEncoder() { - Encoder> encoder2 = e.tuple(e.INT(), e.STRING()); + Encoder> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING()); List> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b")); Dataset> ds2 = context.createDataset(data2, encoder2); Assert.assertEquals(data2, ds2.collectAsList()); - Encoder> encoder3 = e.tuple(e.INT(), e.LONG(), e.STRING()); + Encoder> encoder3 = + Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING()); List> data3 = Arrays.asList(new Tuple3(1, 2L, "a")); Dataset> ds3 = context.createDataset(data3, encoder3); Assert.assertEquals(data3, ds3.collectAsList()); Encoder> encoder4 = - e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING()); + Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING()); List> data4 = Arrays.asList(new Tuple4(1, "b", 2L, "a")); Dataset> ds4 = context.createDataset(data4, encoder4); Assert.assertEquals(data4, ds4.collectAsList()); Encoder> encoder5 = - e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING(), e.BOOLEAN()); + Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING(), + Encoders.BOOLEAN()); List> data5 = Arrays.asList(new Tuple5(1, "b", 2L, "a", true)); Dataset> ds5 = @@ -330,7 +355,7 @@ public void testTupleEncoder() { public void testNestedTupleEncoder() { // test ((int, string), string) Encoder, String>> encoder = - e.tuple(e.tuple(e.INT(), e.STRING()), e.STRING()); + Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING()); List, String>> data = Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b")); Dataset, String>> ds = context.createDataset(data, encoder); @@ -338,7 +363,8 @@ public void testNestedTupleEncoder() { // test (int, (string, string, long)) Encoder>> encoder2 = - e.tuple(e.INT(), e.tuple(e.STRING(), e.STRING(), e.LONG())); + Encoders.tuple(Encoders.INT(), + Encoders.tuple(Encoders.STRING(), Encoders.STRING(), Encoders.LONG())); List>> data2 = Arrays.asList(tuple2(1, new Tuple3("a", "b", 3L))); Dataset>> ds2 = @@ -347,7 +373,8 @@ public void testNestedTupleEncoder() { // test (int, ((string, long), string)) Encoder, String>>> encoder3 = - e.tuple(e.INT(), e.tuple(e.tuple(e.STRING(), e.LONG()), e.STRING())); + Encoders.tuple(Encoders.INT(), + Encoders.tuple(Encoders.tuple(Encoders.STRING(), Encoders.LONG()), Encoders.STRING())); List, String>>> data3 = Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b"))); Dataset, String>>> ds3 = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index fa559c9c64005..3eae3f6d85066 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.scalatest.Matchers._ -import org.apache.spark.sql.execution.{Project, TungstenProject} +import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -563,6 +563,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { df.select(monotonicallyIncreasingId()), Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil ) + checkAnswer( + df.select(expr("monotonically_increasing_id()")), + Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil + ) } test("sparkPartitionId") { @@ -615,8 +619,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { val projects = df.queryExecution.executedPlan.collect { - case project: Project => project - case tungstenProject: TungstenProject => tungstenProject + case tungstenProject: Project => tungstenProject } assert(projects.size === expectedNumProjects) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 2e679e7bc4e0a..432e8d17623a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -162,6 +162,31 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("multiple column distinct count") { + val df1 = Seq( + ("a", "b", "c"), + ("a", "b", "c"), + ("a", "b", "d"), + ("x", "y", "z"), + ("x", "q", null.asInstanceOf[String])) + .toDF("key1", "key2", "key3") + + checkAnswer( + df1.agg(countDistinct('key1, 'key2)), + Row(3) + ) + + checkAnswer( + df1.agg(countDistinct('key1, 'key2, 'key3)), + Row(3) + ) + + checkAnswer( + df1.groupBy('key1).agg(countDistinct('key2, 'key3)), + Seq(Row("a", 2), Row("x", 1)) + ) + } + test("zero count") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( @@ -170,7 +195,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("stddev") { - val testData2ADev = math.sqrt(4 / 5.0) + val testData2ADev = math.sqrt(4.0 / 5.0) checkAnswer( testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)), Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev)) @@ -180,7 +205,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)), - Row(null, null, null)) + Row(Double.NaN, Double.NaN, Double.NaN)) } test("zero sum") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 3a3f19af1473b..aff9efe4b2b16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -308,10 +308,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null, null)) ) - val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b") - assert(intercept[AnalysisException] { - df2.selectExpr("sort_array(a)").collect() - }.getMessage().contains("does not support sorting array of type array")) + val df2 = Seq((Array[Array[Int]](Array(2), Array(1), Array(2, 4), null), "x")).toDF("a", "b") + checkAnswer( + df2.selectExpr("sort_array(a, true)", "sort_array(a, false)"), + Seq( + Row( + Seq[Seq[Int]](null, Seq(1), Seq(2), Seq(2, 4)), + Seq[Seq[Int]](Seq(2, 4), Seq(2), Seq(1), null))) + ) val df3 = Seq(("xxx", "x")).toDF("a", "b") assert(intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala new file mode 100644 index 0000000000000..0c23d142670c1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +class DataFramePivotSuite extends QueryTest with SharedSQLContext{ + import testImplicits._ + + test("pivot courses with literals") { + checkAnswer( + courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")) + .agg(sum($"earnings")), + Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil + ) + } + + test("pivot year with literals") { + checkAnswer( + courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot courses with literals and multiple aggregations") { + checkAnswer( + courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")) + .agg(sum($"earnings"), avg($"earnings")), + Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) :: + Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil + ) + } + + test("pivot year with string values (cast)") { + checkAnswer( + courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot year with int values") { + checkAnswer( + courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot courses with no values") { + // Note Java comes before dotNet in sorted order + checkAnswer( + courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")), + Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil + ) + } + + test("pivot year with no values") { + checkAnswer( + courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("pivot max values inforced") { + sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1) + intercept[RuntimeException]( + courseSales.groupBy($"year").pivot($"course") + ) + sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f3a7aa280367a..35cdab50bdec9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -459,7 +459,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val emptyDescribeResult = Seq( Row("count", "0", "0"), Row("mean", null, null), - Row("stddev", null, null), + Row("stddev", "NaN", "NaN"), Row("min", null, null), Row("max", null, null)) @@ -621,11 +621,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-6899: type should match when using codegen") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer( - decimalData.agg(avg('a)), - Row(new java.math.BigDecimal(2.0))) - } + checkAnswer(decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) } test("SPARK-7133: Implement struct, array, and map field accessor") { @@ -844,31 +840,16 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-8608: call `show` on local DataFrame with random columns should return same value") { - // Make sure we can pass this test for both codegen mode and interpreted mode. - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - val df = testData.select(rand(33)) - assert(df.showString(5) == df.showString(5)) - } - - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - val df = testData.select(rand(33)) - assert(df.showString(5) == df.showString(5)) - } + val df = testData.select(rand(33)) + assert(df.showString(5) == df.showString(5)) // We will reuse the same Expression object for LocalRelation. - val df = (1 to 10).map(Tuple1.apply).toDF().select(rand(33)) - assert(df.showString(5) == df.showString(5)) + val df1 = (1 to 10).map(Tuple1.apply).toDF().select(rand(33)) + assert(df1.showString(5) == df1.showString(5)) } test("SPARK-8609: local DataFrame with random columns should return same value after sort") { - // Make sure we can pass this test for both codegen mode and interpreted mode. - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) - } - - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) - } + checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) // We will reuse the same Expression object for LocalRelation. val df = (1 to 10).map(Tuple1.apply).toDF() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index 7ae12a7895f7e..68e99d6a6b816 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -31,52 +31,46 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("test simple types") { - withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { - val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") - assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) - } + val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") + assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) } test("test struct type") { - withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { - val struct = Row(1, 2L, 3.0F, 3.0) - val data = sparkContext.parallelize(Seq(Row(1, struct))) + val struct = Row(1, 2L, 3.0F, 3.0) + val data = sparkContext.parallelize(Seq(Row(1, struct))) - val schema = new StructType() - .add("a", IntegerType) - .add("b", - new StructType() - .add("b1", IntegerType) - .add("b2", LongType) - .add("b3", FloatType) - .add("b4", DoubleType)) + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType)) - val df = sqlContext.createDataFrame(data, schema) - assert(df.select("b").first() === Row(struct)) - } + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(struct)) } test("test nested struct type") { - withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { - val innerStruct = Row(1, "abcd") - val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") - val data = sparkContext.parallelize(Seq(Row(1, outerStruct))) + val innerStruct = Row(1, "abcd") + val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") + val data = sparkContext.parallelize(Seq(Row(1, outerStruct))) - val schema = new StructType() - .add("a", IntegerType) - .add("b", - new StructType() - .add("b1", IntegerType) - .add("b2", LongType) - .add("b3", FloatType) - .add("b4", DoubleType) - .add("b5", new StructType() - .add("b5a", IntegerType) - .add("b5b", StringType)) - .add("b6", StringType)) + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType) + .add("b5", new StructType() + .add("b5a", IntegerType) + .add("b5b", StringType)) + .add("b6", StringType)) - val df = sqlContext.createDataFrame(data, schema) - assert(df.select("b").first() === Row(outerStruct)) - } + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(outerStruct)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala new file mode 100644 index 0000000000000..46f9f077fe7f2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + + +import scala.language.postfixOps + +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.expressions.Aggregator + +/** An `Aggregator` that adds up any numeric type returned by the given function. */ +class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable { + val numeric = implicitly[Numeric[N]] + + override def zero: N = numeric.zero + + override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) + + override def merge(b1: N, b2: N): N = numeric.plus(b1, b2) + + override def finish(reduction: N): N = reduction +} + +object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] with Serializable { + override def zero: (Long, Long) = (0, 0) + + override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = { + (countAndSum._1 + 1, countAndSum._2 + input._2) + } + + override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = { + (b1._1 + b2._1, b1._2 + b2._2) + } + + override def finish(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1 +} + +object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] + with Serializable { + + override def zero: (Long, Long) = (0, 0) + + override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = { + (countAndSum._1 + 1, countAndSum._2 + input._2) + } + + override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = { + (b1._1 + b2._1, b1._2 + b2._2) + } + + override def finish(reduction: (Long, Long)): (Long, Long) = reduction +} + +case class AggData(a: Int, b: String) +object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable { + /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + override def zero: Int = 0 + + /** + * Combine two values to produce a new value. For performance, the function may modify `b` and + * return it instead of constructing new object for b. + */ + override def reduce(b: Int, a: AggData): Int = b + a.a + + /** + * Transform the output of the reduction. + */ + override def finish(reduction: Int): Int = reduction + + /** + * Merge two intermediate values + */ + override def merge(b1: Int, b2: Int): Int = b1 + b2 +} + +class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { + + import testImplicits._ + + def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = + new SumOf(f).toColumn + + test("typed aggregation: TypedAggregator") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg(sum(_._2)), + ("a", 30), ("b", 3), ("c", 1)) + } + + test("typed aggregation: TypedAggregator, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + sum(_._2), + expr("sum(_2)").as[Int], + count("*")), + ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)) + } + + test("typed aggregation: complex case") { + val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + expr("avg(_2)").as[Double], + TypedAverage.toColumn), + ("a", 2.0, 2.0), ("b", 3.0, 3.0)) + } + + test("typed aggregation: complex result type") { + val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + expr("avg(_2)").as[Double], + ComplexResultAgg.toColumn), + ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L))) + } + + test("typed aggregation: in project list") { + val ds = Seq(1, 3, 2, 5).toDS() + + checkAnswer( + ds.select(sum((i: Int) => i)), + 11) + checkAnswer( + ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)), + 11 -> 22) + } + + test("typed aggregation: class input") { + val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS() + + checkAnswer( + ds.select(ClassInputAgg.toColumn), + 3) + } + + test("typed aggregation: class input with reordering") { + val ds = sql("SELECT 'one' AS b, 1 as a").as[AggData] + + checkAnswer( + ds.select(ClassInputAgg.toColumn), + 1) + + checkAnswer( + ds.select(expr("avg(a)").as[Double], ClassInputAgg.toColumn), + (1.0, 1)) + + checkAnswer( + ds.groupBy(_.b).agg(ClassInputAgg.toColumn), + ("one", 1)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index e3b0346f857d3..63b00975e4eb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -75,11 +75,6 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { assert(ds.reduce(_ + _) == 6) } - test("fold") { - val ds = Seq(1, 2, 3).toDS() - assert(ds.fold(0)(_ + _) == 6) - } - test("groupBy function, keys") { val ds = Seq(1, 2, 3, 4, 5).toDS() val grouped = ds.groupBy(_ % 2) @@ -88,16 +83,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { 0, 1) } - test("groupBy function, mapGroups") { + test("groupBy function, map") { val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() val grouped = ds.groupBy(_ % 2) - val agged = grouped.mapGroups { case (g, iter) => + val agged = grouped.map { case (g, iter) => val name = if (g == 0) "even" else "odd" - Iterator((name, iter.size)) + (name, iter.size) } checkAnswer( agged, ("even", 5), ("odd", 6)) } + + test("groupBy function, flatMap") { + val ds = Seq("a", "b", "c", "xyz", "hello").toDS() + val grouped = ds.groupBy(_.length) + val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) } + + checkAnswer( + agged, + "1", "abc", "3", "xyz", "5", "hello") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d61e17edc64ed..c23dd46d3767b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -61,6 +61,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.collect() === Array(ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))) } + test("as case class - take") { + val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData] + assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2))) + } + test("map") { val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() checkAnswer( @@ -137,11 +142,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) } - test("fold") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() - assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) - } - test("joinWith, flat schema") { val ds1 = Seq(1, 2, 3).toDS().as("a") val ds2 = Seq(1, 2).toDS().as("b") @@ -198,60 +198,69 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (1, 1)) } - test("groupBy function, mapGroups") { + test("groupBy function, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g._1, iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g._1, iter.map(_._2).sum) } checkAnswer( agged, ("a", 30), ("b", 3), ("c", 1)) } - test("groupBy columns, mapGroups") { + test("groupBy function, fatMap") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy(v => (v._1, "word")) + val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) } + + checkAnswer( + agged, + "a", "30", "b", "3", "c", "1") + } + + test("groupBy function, reduce") { + val ds = Seq("abc", "xyz", "hello").toDS() + val agged = ds.groupBy(_.length).reduce(_ + _) + + checkAnswer( + agged, + 3 -> "abcxyz", 5 -> "hello") + } + + test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g.getString(0), iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } checkAnswer( agged, ("a", 30), ("b", 3), ("c", 1)) } - test("groupBy columns asKey, mapGroups") { + test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1").asKey[String] - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g, iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, ("a", 30), ("b", 3), ("c", 1)) } - test("groupBy columns asKey tuple, mapGroups") { + test("groupBy columns asKey tuple, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)] - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g, iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, (("a", 1), 30), (("b", 1), 3), (("c", 1), 1)) } - test("groupBy columns asKey class, mapGroups") { + test("groupBy columns asKey class, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData] - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g, iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, @@ -313,4 +322,9 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") checkAnswer(joined, ("2", 2)) } + + test("toString") { + val ds = Seq((1, 2)).toDS() + assert(ds.toString == "[_1: int, _2: int]") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 9080c53c491ac..1266d534cc5b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -444,6 +444,27 @@ class DateFunctionsSuite extends QueryTest with SharedSQLContext { Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) } + test("to_unix_timestamp") { + val date1 = Date.valueOf("2015-07-24") + val date2 = Date.valueOf("2015-07-25") + val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") + val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") + val s1 = "2015/07/24 10:00:00.5" + val s2 = "2015/07/25 02:02:02.6" + val ss1 = "2015-07-24 10:00:00" + val ss2 = "2015-07-25 02:02:02" + val fmt = "yyyy/MM/dd HH:mm:ss.S" + val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") + checkAnswer(df.selectExpr("to_unix_timestamp(ts)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr("to_unix_timestamp(ss)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"to_unix_timestamp(d, '$fmt')"), Seq( + Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"to_unix_timestamp(s, '$fmt')"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + } + test("datediff") { val df = Seq( (Date.valueOf("2015-07-24"), Timestamp.valueOf("2015-07-24 01:00:00"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index a9ca46cab067d..9a3c262e9485d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -44,8 +44,6 @@ class JoinSuite extends QueryTest with SharedSQLContext { val df = sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { - case j: ShuffledHashJoin => j - case j: ShuffledHashOuterJoin => j case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j case j: BroadcastHashOuterJoin => j @@ -96,75 +94,39 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { - Seq( - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", - classOf[ShuffledHashJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", - classOf[ShuffledHashJoin]), - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), - ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[ShuffledHashOuterJoin]), - ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[ShuffledHashOuterJoin]), - ("SELECT * FROM testData full outer join testData2 ON key = a", - classOf[ShuffledHashOuterJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } } - test("SortMergeJoin shouldn't work on unsortable columns") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - Seq( - ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } - } +// ignore("SortMergeJoin shouldn't work on unsortable columns") { +// Seq( +// ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) +// ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } +// } test("broadcasted hash join operator selection") { sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") - for (sortMergeJoinEnabled <- Seq(true, false)) { - withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") { - Seq( - ("SELECT * FROM testData join testData2 ON key = a", - classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a and key = 2", - classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", - classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } - } - } + Seq( + ("SELECT * FROM testData join testData2 ON key = a", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", + classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } sql("UNCACHE TABLE testData") } test("broadcasted hash outer join operator selection") { sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", - classOf[SortMergeOuterJoin]), - ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[BroadcastHashOuterJoin]), - ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[BroadcastHashOuterJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { - Seq( - ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", - classOf[ShuffledHashOuterJoin]), - ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[BroadcastHashOuterJoin]), - ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[BroadcastHashOuterJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } - } + Seq( + ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", + classOf[SortMergeOuterJoin]), + ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", + classOf[BroadcastHashOuterJoin]), + ("SELECT * FROM testData right join testData2 ON key = a and key = 2", + classOf[BroadcastHashOuterJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } sql("UNCACHE TABLE testData") } @@ -279,16 +241,17 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer( sql( """ - |SELECT l.N, count(*) - |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) - |GROUP BY l.N - """.stripMargin), - Row(1, 1) :: - Row(2, 1) :: - Row(3, 1) :: - Row(4, 1) :: - Row(5, 1) :: - Row(6, 1) :: Nil) + |SELECT l.N, count(*) + |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY l.N + """. + stripMargin), + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: Nil) checkAnswer( sql( @@ -343,7 +306,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY l.a """.stripMargin), - Row(null, 6)) + Row(null, + 6)) checkAnswer( sql( @@ -352,7 +316,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY r.N """.stripMargin), - Row(1, 1) :: + Row(1 + , 1) :: Row(2, 1) :: Row(3, 1) :: Row(4, 1) :: @@ -396,14 +361,16 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) - // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. + // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join + // operator. checkAnswer( sql( """ - |SELECT l.a, count(*) - |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) - |GROUP BY l.a - """.stripMargin), + |SELECT l.a, count(*) + |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY l.a + """. + stripMargin), Row(null, 10)) checkAnswer( @@ -413,7 +380,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) |GROUP BY r.N """.stripMargin), - Row(1, 1) :: + Row + (1, 1) :: Row(2, 1) :: Row(3, 1) :: Row(4, 1) :: @@ -428,7 +396,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) |GROUP BY l.N """.stripMargin), - Row(1, 1) :: + Row(1 + , 1) :: Row(2, 1) :: Row(3, 1) :: Row(4, 1) :: @@ -439,10 +408,11 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer( sql( """ - |SELECT r.a, count(*) - |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) - |GROUP BY r.a - """.stripMargin), + |SELECT r.a, count(*) + |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY r.a + """. + stripMargin), Row(null, 10)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index e3531d0d6d799..14fd56fc8c222 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -41,23 +41,26 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { test("json_tuple select") { val df: DataFrame = tuples.toDF("key", "jstring") - val expected = Row("1", Row("value1", "value2", "3", null, "5.23")) :: - Row("2", Row("value12", "2", "value3", "4.01", null)) :: - Row("3", Row("value13", "2", "value33", "value44", "5.01")) :: - Row("4", Row(null, null, null, null, null)) :: - Row("5", Row("", null, null, null, null)) :: - Row("6", Row(null, null, null, null, null)) :: + val expected = + Row("1", "value1", "value2", "3", null, "5.23") :: + Row("2", "value12", "2", "value3", "4.01", null) :: + Row("3", "value13", "2", "value33", "value44", "5.01") :: + Row("4", null, null, null, null, null) :: + Row("5", "", null, null, null, null) :: + Row("6", null, null, null, null, null) :: Nil - checkAnswer(df.selectExpr("key", "json_tuple(jstring, 'f1', 'f2', 'f3', 'f4', 'f5')"), expected) + checkAnswer( + df.select($"key", functions.json_tuple($"jstring", "f1", "f2", "f3", "f4", "f5")), + expected) } test("json_tuple filter and group") { val df: DataFrame = tuples.toDF("key", "jstring") val expr = df - .selectExpr("json_tuple(jstring, 'f1', 'f2') as jt") - .where($"jt.c0".isNotNull) - .groupBy($"jt.c1") + .select(functions.json_tuple($"jstring", "f1", "f2")) + .where($"c0".isNotNull) + .groupBy($"c1") .count() val expected = Row(null, 1) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 3c174efe73ffe..b5417b195f396 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.catalyst.encoders.Encoder abstract class QueryTest extends PlanTest { @@ -83,18 +82,21 @@ abstract class QueryTest extends PlanTest { fail( s""" |Exception collecting dataset as objects - |${ds.encoder} - |${ds.encoder.constructExpression.treeString} + |${ds.resolvedTEncoder} + |${ds.resolvedTEncoder.fromRowExpression.treeString} |${ds.queryExecution} """.stripMargin, e) } if (decoded != expectedAnswer.toSet) { + val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted + val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted + + val comparision = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n") fail( s"""Decoded objects do not match expected objects: - |Expected: ${expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted} - |Actual ${decoded.toSet.toSeq.map((a: Any) => a.toString).sorted} - |${ds.encoder.constructExpression.treeString} + |$comparision + |${ds.resolvedTEncoder.fromRowExpression.treeString} """.stripMargin) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 3de277a79a52c..167aea87de077 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -237,34 +237,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-8828 sum should return null if all input values are null") { - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - } - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - } + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) } private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { @@ -285,8 +261,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("aggregation with codegen") { - val originalValue = sqlContext.conf.codegenEnabled - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) // Prepare a table that we can group some rows. sqlContext.table("testData") .unionAll(sqlContext.table("testData")) @@ -340,13 +314,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { testCodeGen( "SELECT min(key) FROM testData3x", Row(1) :: Nil) - // STDDEV - testCodeGen( - "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a", - (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25)))) - testCodeGen( - "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2", - Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil) // Some combinations. testCodeGen( """ @@ -367,11 +334,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(100, 1, 50.5, 300, 100) :: Nil) // Aggregate with Code generation handling all null values testCodeGen( - "SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData", - Row(null, null, null, 0) :: Nil) + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) } } @@ -507,29 +473,22 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("literal in agg grouping expressions") { - def literalInAggTest(): Unit = { - checkAnswer( - sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - checkAnswer( - sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - checkAnswer( - sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), - sql("SELECT 1, 2, sum(b) FROM testData2")) - } + checkAnswer( + sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + checkAnswer( + sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - literalInAggTest() - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - literalInAggTest() - } + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT 1, 2, sum(b) FROM testData2")) } test("aggregates with nulls") { @@ -598,12 +557,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sortTest() } - test("SPARK-6927 external sorting with codegen on") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - sortTest() - } - } - test("limit") { checkAnswer( sql("SELECT * FROM testData LIMIT 10"), @@ -1655,12 +1608,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("aggregation with codegen updates peak execution memory") { - withSQLConf((SQLConf.CODEGEN_ENABLED.key, "true")) { - AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "aggregation with codegen") { - testCodeGen( - "SELECT key, count(value) FROM testData GROUP BY key", - (1 to 100).map(i => Row(i, 1))) - } + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "aggregation with codegen") { + testCodeGen( + "SELECT key, count(value) FROM testData GROUP BY key", + (1 to 100).map(i => Row(i, 1))) } } @@ -1813,10 +1764,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // This test is for the fix of https://issues.apache.org/jira/browse/SPARK-10737. // This bug will be triggered when Tungsten is enabled and there are multiple // SortMergeJoin operators executed in the same task. - val confs = - SQLConf.SORTMERGE_JOIN.key -> "true" :: - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" :: - SQLConf.TUNGSTEN_ENABLED.key -> "true" :: Nil + val confs = SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" :: Nil withSQLConf(confs: _*) { val df1 = (1 to 50).map(i => (s"str_$i", i)).toDF("i", "j") val df2 = @@ -2001,4 +1949,52 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) } } + + test("Common subexpression elimination") { + // select from a table to prevent constant folding. + val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. + val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) + + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. + def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) + } + + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Would be nice if semantic equals for `+` understood commutative + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Try disabling it via configuration. + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index a229e5814df89..f602f2fb89ca5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -21,16 +21,12 @@ import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} import scala.beans.{BeanInfo, BeanProperty} -import com.clearspring.analytics.stream.cardinality.HyperLogLog - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet @@ -134,25 +130,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) } - test("HyperLogLogUDT") { - val hyperLogLogUDT = HyperLogLogUDT - val hyperLogLog = new HyperLogLog(0.4) - (1 to 10).foreach(i => hyperLogLog.offer(Row(i))) - - val actual = hyperLogLogUDT.deserialize(hyperLogLogUDT.serialize(hyperLogLog)) - assert(actual.cardinality() === hyperLogLog.cardinality()) - assert(java.util.Arrays.equals(actual.getBytes, hyperLogLog.getBytes)) - } - - test("OpenHashSetUDT") { - val openHashSetUDT = new OpenHashSetUDT(IntegerType) - val set = new OpenHashSet[Int] - (1 to 10).foreach(i => set.add(i)) - - val actual = openHashSetUDT.deserialize(openHashSetUDT.serialize(set)) - assert(actual.iterator.toSet === set.iterator.toSet) - } - test("UDTs with JSON") { val data = Seq( "{\"id\":1,\"vec\":[1.1,2.2,3.3,4.4]}", @@ -176,7 +153,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT test("SPARK-10472 UserDefinedType.typeName") { assert(IntegerType.typeName === "integer") assert(new MyDenseVectorUDT().typeName === "mydensevector") - assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset") } test("Catalyst type converter null handling for UDTs") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 2076c573b56c1..dfec139985f73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.execution.joins.{SortMergeJoin, BroadcastHashJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -38,7 +38,7 @@ class PlannerSuite extends SharedSQLContext { private def testPartialAggregationPlan(query: LogicalPlan): Unit = { val planner = sqlContext.planner import planner._ - val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) + val plannedOption = Aggregation(query).headOption val planned = plannedOption.getOrElse( fail(s"Could query play aggregation query $query. Is it an aggregation query?")) @@ -97,10 +97,10 @@ class PlannerSuite extends SharedSQLContext { """.stripMargin).queryExecution.executedPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } - val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + val sortMergeJoins = planned.collect { case join: SortMergeJoin => join } assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + assert(sortMergeJoins.isEmpty, "Should not use sort merge join") } } @@ -150,16 +150,30 @@ class PlannerSuite extends SharedSQLContext { val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } - val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + val sortMergeJoins = planned.collect { case join: SortMergeJoin => join } assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + assert(sortMergeJoins.isEmpty, "Should not use sort merge join") sqlContext.clearCache() } } } + test("SPARK-11390 explain should print PushedFilters of PhysicalRDD") { + withTempPath { file => + val path = file.getCanonicalPath + testData.write.parquet(path) + val df = sqlContext.read.parquet(path) + sqlContext.registerDataFrameAsTable(df, "testPushed") + + withTempTable("testPushed") { + val exp = sql("select * from testPushed where key = 15").queryExecution.executedPlan + assert(exp.toString.contains("PushedFilter: [EqualTo(key,15)]")) + } + } + } + test("efficient limit -> project -> sort") { { val query = @@ -365,7 +379,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: TungstenSort => true; case s: Sort => true }.isEmpty) { + if (outputPlan.collect { case s: Sort => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") } } @@ -381,7 +395,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: TungstenSort => true; case s: Sort => true }.nonEmpty) { + if (outputPlan.collect { case s: Sort => true }.nonEmpty) { fail(s"No sorts should have been added:\n$outputPlan") } } @@ -398,7 +412,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: TungstenSort => true; case s: Sort => true }.isEmpty) { + if (outputPlan.collect { case s: Sort => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala new file mode 100644 index 0000000000000..9575d26fd123f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.{InternalAccumulator, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.ExternalSorter + + +/** + * A reference sort implementation used to compare against our normal sort. + */ +case class ReferenceSort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) + extends UnaryNode { + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { + child.execute().mapPartitions( { iterator => + val ordering = newOrdering(sortOrder, child.output) + val sorter = new ExternalSorter[InternalRow, Null, InternalRow]( + TaskContext.get(), ordering = Some(ordering)) + sorter.insertAll(iterator.map(r => (r.copy(), null))) + val baseIterator = sorter.iterator.map(_._1) + val context = TaskContext.get() + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) + CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) + }, preservesPartitioning = true) + } + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index b3fceeab64cfe..6876ab0f02b10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -33,9 +33,9 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { case c: ConvertToSafe => c } - private val outputsSafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) + private val outputsSafe = ReferenceSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) + private val outputsUnsafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(outputsUnsafe.outputsUnsafeRows) test("planner should insert unsafe->safe conversions when required") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index 847c188a30333..e5d34be4c65e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -17,15 +17,22 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import scala.util.Random + +import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{RandomDataGenerator, Row} + +/** + * Test sorting. Many of the test cases generate random data and compares the sorted result with one + * sorted by a reference implementation ([[ReferenceSort]]). + */ class SortSuite extends SparkPlanTest with SharedSQLContext { import testImplicits.localSeqToDataFrameHolder - // This test was originally added as an example of how to use [[SparkPlanTest]]; - // it's not designed to be a comprehensive test of ExternalSort. test("basic sorting using ExternalSort") { val input = Seq( @@ -36,14 +43,66 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { checkAnswer( input.toDF("a", "b", "c"), - Sort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), + (child: SparkPlan) => Sort('a.asc :: 'b.asc :: Nil, global = true, child = child), input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), sortAnswers = false) checkAnswer( input.toDF("a", "b", "c"), - Sort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), + (child: SparkPlan) => Sort('b.asc :: 'a.asc :: Nil, global = true, child = child), input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), sortAnswers = false) } + + test("sort followed by limit") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child = child)), + (child: SparkPlan) => Limit(10, ReferenceSort('a.asc :: Nil, global = true, child)), + sortAnswers = false + ) + } + + test("sorting does not crash for large inputs") { + val sortOrder = 'a.asc :: Nil + val stringLength = 1024 * 1024 * 2 + checkThatPlansAgree( + Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), + Sort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + ReferenceSort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + + test("sorting updates peak execution memory") { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child = child), + (child: SparkPlan) => ReferenceSort('a.asc :: Nil, global = true, child), + sortAnswers = false) + } + } + + // Test sorting on different data types + for ( + dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); + nullable <- Seq(true, false); + sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); + randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) + ) { + test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { + val inputData = Seq.fill(1000)(randomDataGenerator()) + val inputDf = sqlContext.createDataFrame( + sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + StructType(StructField("a", dataType, nullable = true) :: Nil) + ) + checkThatPlansAgree( + inputDf, + p => ConvertToSafe(Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23)), + ReferenceSort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala deleted file mode 100644 index 7a0f0dfd2b7f1..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import scala.util.Random - -import org.apache.spark.AccumulatorSuite -import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ - -/** - * A test suite that generates randomized data to test the [[TungstenSort]] operator. - */ -class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { - import testImplicits.localSeqToDataFrameHolder - - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) - } - - override def afterAll(): Unit = { - try { - sqlContext.conf.unsetConf(SQLConf.CODEGEN_ENABLED) - } finally { - super.afterAll() - } - } - - test("sort followed by limit") { - checkThatPlansAgree( - (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), - (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), - sortAnswers = false - ) - } - - test("sorting does not crash for large inputs") { - val sortOrder = 'a.asc :: Nil - val stringLength = 1024 * 1024 * 2 - checkThatPlansAgree( - Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), - TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), - Sort(sortOrder, global = true, _: SparkPlan), - sortAnswers = false - ) - } - - test("sorting updates peak execution memory") { - AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") { - checkThatPlansAgree( - (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child), - (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child), - sortAnswers = false) - } - } - - // Test sorting on different data types - for ( - dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); - nullable <- Seq(true, false); - sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); - randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) - ) { - test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { - val inputData = Seq.fill(1000)(randomDataGenerator()) - val inputDf = sqlContext.createDataFrame( - sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), - StructType(StructField("a", dataType, nullable = true) :: Nil) - ) - assert(TungstenSort.supportsSchema(inputDf.schema)) - checkThatPlansAgree( - inputDf, - plan => ConvertToSafe( - TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), - Sort(sortOrder, global = true, _: SparkPlan), - sortAnswers = false - ) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala new file mode 100644 index 0000000000000..4cc0a3a9585d9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.json + +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.test.SharedSQLContext + +/** + * Test cases for various [[JSONOptions]]. + */ +class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { + + test("allowComments off") { + val str = """{'name': /* hello */ 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowComments on") { + val str = """{'name': /* hello */ 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowComments", "true").json(rdd) + + assert(df.schema.head.name == "name") + assert(df.first().getString(0) == "Reynold Xin") + } + + test("allowSingleQuotes off") { + val str = """{'name': 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowSingleQuotes", "false").json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowSingleQuotes on") { + val str = """{'name': 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "name") + assert(df.first().getString(0) == "Reynold Xin") + } + + test("allowUnquotedFieldNames off") { + val str = """{name: 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowUnquotedFieldNames on") { + val str = """{name: 'Reynold Xin'}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowUnquotedFieldNames", "true").json(rdd) + + assert(df.schema.head.name == "name") + assert(df.first().getString(0) == "Reynold Xin") + } + + test("allowNumericLeadingZeros off") { + val str = """{"age": 0018}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowNumericLeadingZeros on") { + val str = """{"age": 0018}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowNumericLeadingZeros", "true").json(rdd) + + assert(df.schema.head.name == "age") + assert(df.first().getLong(0) == 18) + } + + // The following two tests are not really working - need to look into Jackson's + // JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS. + ignore("allowNonNumericNumbers off") { + val str = """{"age": NaN}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + ignore("allowNonNumericNumbers on") { + val str = """{"age": NaN}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowNonNumericNumbers", "true").json(rdd) + + assert(df.schema.head.name == "age") + assert(df.first().getDouble(0).isNaN) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 28b8f02bdf87f..6042b1178affe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -588,7 +588,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { relation.isInstanceOf[JSONRelation], "The DataFrame returned by jsonFile should be based on JSONRelation.") assert(relation.asInstanceOf[JSONRelation].paths === Array(path)) - assert(relation.asInstanceOf[JSONRelation].samplingRatio === (0.49 +- 0.001)) + assert(relation.asInstanceOf[JSONRelation].options.samplingRatio === (0.49 +- 0.001)) val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = @@ -597,7 +597,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.paths === Array(path)) assert(relationWithSchema.schema === schema) - assert(relationWithSchema.samplingRatio > 0.99) + assert(relationWithSchema.options.samplingRatio > 0.99) } test("Loading a JSON dataset from a text file") { @@ -1165,31 +1165,28 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("JSONRelation equality test") { val relation0 = new JSONRelation( Some(empty), - 1.0, - false, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(sqlContext) + None, + None)(sqlContext) val logicalRelation0 = LogicalRelation(relation0) val relation1 = new JSONRelation( Some(singleRow), - 1.0, - false, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(sqlContext) + None, + None)(sqlContext) val logicalRelation1 = LogicalRelation(relation1) val relation2 = new JSONRelation( Some(singleRow), - 0.5, - false, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(sqlContext) + None, + None, + parameters = Map("samplingRatio" -> "0.5"))(sqlContext) val logicalRelation2 = LogicalRelation(relation2) val relation3 = new JSONRelation( Some(singleRow), - 1.0, - false, Some(StructType(StructField("b", IntegerType, true) :: Nil)), - None, None)(sqlContext) + None, + None)(sqlContext) val logicalRelation3 = LogicalRelation(relation3) assert(relation0 !== relation1) @@ -1232,7 +1229,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { // This is really a test that it doesn't throw an exception - val emptySchema = InferSchema(empty, 1.0, "") + val emptySchema = InferSchema.infer(empty, "", JSONOptions()) assert(StructType(Seq()) === emptySchema) } @@ -1256,7 +1253,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-8093 Erase empty structs") { - val emptySchema = InferSchema(emptyRecords, 1.0, "") + val emptySchema = InferSchema.infer(emptyRecords, "", JSONOptions()) assert(StructType(Seq()) === emptySchema) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index c24c9f025dad7..458786f77af3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -54,12 +54,12 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex .select(output.map(e => Column(e)): _*) .where(Column(predicate)) - val analyzedPredicate = query.queryExecution.optimizedPlan.collect { + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation, _)) => filters - }.flatten - assert(analyzedPredicate.nonEmpty) + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined) - val selectedFilters = analyzedPredicate.flatMap(DataSourceStrategy.translateFilter) + val selectedFilters = maybeAnalyzedPredicate.flatMap(DataSourceStrategy.translateFilter) assert(selectedFilters.nonEmpty) selectedFilters.foreach { pred => @@ -294,7 +294,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file checkAnswer( - sqlContext.read.parquet(path).filter("part = 1"), + sqlContext.read.parquet(dir.getCanonicalPath).filter("part = 1"), (1 to 3).map(i => Row(i, i.toString, 1))) } } @@ -311,7 +311,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file checkAnswer( - sqlContext.read.parquet(path).filter("a > 0 and (part = 0 or a > 1)"), + sqlContext.read.parquet(dir.getCanonicalPath).filter("a > 0 and (part = 0 or a > 1)"), (2 to 3).map(i => Row(i, i.toString, 1))) } } @@ -336,4 +336,29 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } } + + test("SPARK-11661 Still pushdown filters returned by unhandledFilters") { + import testImplicits._ + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part=1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) + val df = sqlContext.read.parquet(path).filter("a = 2") + + // This is the source RDD without Spark-side filtering. + val childRDD = + df + .queryExecution + .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter] + .child + .execute() + + // The result should be single row. + // When a filter is pushed to Parquet, Parquet can apply it to every row. + // So, we can check the number of rows returned from the Parquet + // to make sure our filter pushdown work. + assert(childRDD.count == 1) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 72744799897be..a148facd056a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.Collections +import org.apache.parquet.column.{Encoding, ParquetProperties} + import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -91,6 +93,33 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } + test("SPARK-11694 Parquet logical types are not being tested properly") { + val parquetSchema = MessageTypeParser.parseMessageType( + """message root { + | required int32 a(INT_8); + | required int32 b(INT_16); + | required int32 c(DATE); + | required int32 d(DECIMAL(1,0)); + | required int64 e(DECIMAL(10,0)); + | required binary f(UTF8); + | required binary g(ENUM); + | required binary h(DECIMAL(32,0)); + | required fixed_len_byte_array(32) i(DECIMAL(32,0)); + |} + """.stripMargin) + + val expectedSparkTypes = Seq(ByteType, ShortType, DateType, DecimalType(1, 0), + DecimalType(10, 0), StringType, StringType, DecimalType(32, 0), DecimalType(32, 0)) + + withTempPath { location => + val path = new Path(location.getCanonicalPath) + val conf = sparkContext.hadoopConfiguration + writeMetadata(parquetSchema, path, conf) + val sparkTypes = sqlContext.read.parquet(path.toString).schema.map(_.dataType) + assert(sparkTypes === expectedSparkTypes) + } + } + test("string") { val data = (1 to 4).map(i => Tuple1(i.toString)) // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL @@ -206,6 +235,55 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } + test("SPARK-10113 Support for unsigned Parquet logical types") { + val parquetSchema = MessageTypeParser.parseMessageType( + """message root { + | required int32 c(UINT_32); + |} + """.stripMargin) + + withTempPath { location => + val extraMetadata = Map.empty[String, String].asJava + val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") + val path = new Path(location.getCanonicalPath) + val footer = List( + new Footer(path, new ParquetMetadata(fileMetadata, Collections.emptyList())) + ).asJava + + ParquetFileWriter.writeMetadataFile(sparkContext.hadoopConfiguration, path, footer) + + val errorMessage = intercept[Throwable] { + sqlContext.read.parquet(path.toString).printSchema() + }.toString + assert(errorMessage.contains("Parquet type not supported")) + } + } + + test("SPARK-11692 Support for Parquet logical types, JSON and BSON (embedded types)") { + val parquetSchema = MessageTypeParser.parseMessageType( + """message root { + | required binary a(JSON); + | required binary b(BSON); + |} + """.stripMargin) + + withTempPath { location => + val extraMetadata = Map.empty[String, String].asJava + val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") + val path = new Path(location.getCanonicalPath) + val footer = List( + new Footer(path, new ParquetMetadata(fileMetadata, Collections.emptyList())) + ).asJava + + ParquetFileWriter.writeMetadataFile(sparkContext.hadoopConfiguration, path, footer) + + val jsonDataType = sqlContext.read.parquet(path.toString).schema(0).dataType + assert(jsonDataType === StringType) + val bsonDataType = sqlContext.read.parquet(path.toString).schema(1).dataType + assert(bsonDataType === BinaryType) + } + } + test("compression codec") { def compressionCodecFor(path: String, codecName: String): String = { val codecs = for { @@ -350,16 +428,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { """.stripMargin) withTempPath { location => - val extraMetadata = Collections.singletonMap( - CatalystReadSupport.SPARK_METADATA_KEY, sparkSchema.toString) - val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") + val extraMetadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) val path = new Path(location.getCanonicalPath) - - ParquetFileWriter.writeMetadataFile( - sparkContext.hadoopConfiguration, - path, - Collections.singletonList( - new Footer(path, new ParquetMetadata(fileMetadata, Collections.emptyList())))) + val conf = sparkContext.hadoopConfiguration + writeMetadata(parquetSchema, path, conf, extraMetadata) assertResult(sqlContext.read.parquet(path.toString).schema) { StructType( @@ -489,6 +561,38 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } + test("SPARK-11044 Parquet writer version fixed as version1 ") { + // For dictionary encoding, Parquet changes the encoding types according to its writer + // version. So, this test checks one of the encoding types in order to ensure that + // the file is written with writer version2. + withTempPath { dir => + val clonedConf = new Configuration(hadoopConfiguration) + try { + // Write a Parquet file with writer version2. + hadoopConfiguration.set(ParquetOutputFormat.WRITER_VERSION, + ParquetProperties.WriterVersion.PARQUET_2_0.toString) + + // By default, dictionary encoding is enabled from Parquet 1.2.0 but + // it is enabled just in case. + hadoopConfiguration.setBoolean(ParquetOutputFormat.ENABLE_DICTIONARY, true) + val path = s"${dir.getCanonicalPath}/part-r-0.parquet" + sqlContext.range(1 << 16).selectExpr("(id % 4) AS i") + .coalesce(1).write.mode("overwrite").parquet(path) + + val blockMetadata = readFooter(new Path(path), hadoopConfiguration).getBlocks.asScala.head + val columnChunkMetadata = blockMetadata.getColumns.asScala.head + + // If the file is written with version2, this should include + // Encoding.RLE_DICTIONARY type. For version1, it is Encoding.PLAIN_DICTIONARY + assert(columnChunkMetadata.getEncodings.contains(Encoding.RLE_DICTIONARY)) + } finally { + // Manually clear the hadoop configuration for other tests. + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) + } + } + } + test("read dictionary encoded decimals written as INT32") { checkAnswer( // Decimal column in this file is encoded using plain dictionary diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 61cc0da50865c..71e9034d97792 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -66,7 +66,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha "hdfs://host:9000/path/a=10.5/b=hello") var exception = intercept[AssertionError] { - parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) + parsePartitions(paths.map(new Path(_)), defaultPartitionName, true, Set.empty[Path]) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -76,7 +76,37 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha "hdfs://host:9000/path/a=10/b=20", "hdfs://host:9000/path/_temporary/path") - parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + Set(new Path("hdfs://host:9000/path/"))) + + // Valid + paths = Seq( + "hdfs://host:9000/path/something=true/table/", + "hdfs://host:9000/path/something=true/table/_temporary", + "hdfs://host:9000/path/something=true/table/a=10/b=20", + "hdfs://host:9000/path/something=true/table/_temporary/path") + + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + Set(new Path("hdfs://host:9000/path/something=true/table"))) + + // Valid + paths = Seq( + "hdfs://host:9000/path/table=true/", + "hdfs://host:9000/path/table=true/_temporary", + "hdfs://host:9000/path/table=true/a=10/b=20", + "hdfs://host:9000/path/table=true/_temporary/path") + + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + Set(new Path("hdfs://host:9000/path/table=true"))) // Invalid paths = Seq( @@ -85,7 +115,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha "hdfs://host:9000/path/path1") exception = intercept[AssertionError] { - parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + Set(new Path("hdfs://host:9000/path/"))) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -101,19 +135,24 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha "hdfs://host:9000/tmp/tables/nonPartitionedTable2") exception = intercept[AssertionError] { - parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + Set(new Path("hdfs://host:9000/tmp/tables/"))) } assert(exception.getMessage().contains("Conflicting directory structures detected")) } test("parse partition") { def check(path: String, expected: Option[PartitionValues]): Unit = { - assert(expected === parsePartition(new Path(path), defaultPartitionName, true)._1) + val actual = parsePartition(new Path(path), defaultPartitionName, true, Set.empty[Path])._1 + assert(expected === actual) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), defaultPartitionName, true) + parsePartition(new Path(path), defaultPartitionName, true, Set.empty[Path]) }.getMessage assert(message.contains(expected)) @@ -152,8 +191,17 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } test("parse partitions") { - def check(paths: Seq[String], spec: PartitionSpec): Unit = { - assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, true) === spec) + def check( + paths: Seq[String], + spec: PartitionSpec, + rootPaths: Set[Path] = Set.empty[Path]): Unit = { + val actualSpec = + parsePartitions( + paths.map(new Path(_)), + defaultPartitionName, + true, + rootPaths) + assert(actualSpec === spec) } check(Seq( @@ -232,7 +280,9 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("parse partitions with type inference disabled") { def check(paths: Seq[String], spec: PartitionSpec): Unit = { - assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName, false) === spec) + val actualSpec = + parsePartitions(paths.map(new Path(_)), defaultPartitionName, false, Set.empty[Path]) + assert(actualSpec === spec) } check(Seq( @@ -590,6 +640,70 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } } + test("SPARK-11678: Partition discovery stops at the root path of the dataset") { + withTempPath { dir => + val tablePath = new File(dir, "key=value") + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(tablePath.getCanonicalPath) + + Files.touch(new File(s"${tablePath.getCanonicalPath}/", "_SUCCESS")) + Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) + + checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + } + + withTempPath { dir => + val path = new File(dir, "key=value") + val tablePath = new File(path, "table") + + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(tablePath.getCanonicalPath) + + Files.touch(new File(s"${tablePath.getCanonicalPath}/", "_SUCCESS")) + Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) + + checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + } + } + + test("use basePath to specify the root dir of a partitioned table.") { + withTempPath { dir => + val tablePath = new File(dir, "table") + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(tablePath.getCanonicalPath) + + val twoPartitionsDF = + sqlContext + .read + .option("basePath", tablePath.getCanonicalPath) + .parquet( + s"${tablePath.getCanonicalPath}/b=1", + s"${tablePath.getCanonicalPath}/b=2") + + checkAnswer(twoPartitionsDF, df.filter("b != 3")) + + intercept[AssertionError] { + sqlContext + .read + .parquet( + s"${tablePath.getCanonicalPath}/b=1", + s"${tablePath.getCanonicalPath}/b=2") + } + } + } + test("listConflictingPartitionColumns") { def makeExpectedMessage(colNameLists: Seq[String], paths: Seq[String]): String = { val conflictingColNameLists = colNameLists.zipWithIndex.map { case (list, index) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 8ffb01fc5b584..fdd7697c91f5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File +import org.apache.parquet.schema.MessageType + import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -117,6 +119,21 @@ private[sql] trait ParquetTest extends SQLTestUtils { ParquetFileWriter.writeMetadataFile(configuration, path, Seq(footer).asJava) } + /** + * This is an overloaded version of `writeMetadata` above to allow writing customized + * Parquet schema. + */ + protected def writeMetadata( + parquetSchema: MessageType, path: Path, configuration: Configuration, + extraMetadata: Map[String, String] = Map.empty[String, String]): Unit = { + val extraMetadataAsJava = extraMetadata.asJava + val createdBy = s"Apache Spark ${org.apache.spark.SPARK_VERSION}" + val fileMetadata = new FileMetaData(parquetSchema, extraMetadataAsJava, createdBy) + val parquetMetadata = new ParquetMetadata(fileMetadata, Seq.empty[BlockMetaData].asJava) + val footer = new Footer(path, parquetMetadata) + ParquetFileWriter.writeMetadataFile(configuration, path, Seq(footer).asJava) + } + protected def readAllFootersWithoutSummaryFiles( path: Path, configuration: Configuration): Seq[Footer] = { val fs = path.getFileSystem(configuration) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index dcbfdca71acb6..5b2998c3c76d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} /** - * Test various broadcast join operators with unsafe enabled. + * Test various broadcast join operators. * * Tests in this suite we need to run Spark in local-cluster mode. In particular, the use of * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered @@ -45,8 +45,6 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { .setAppName("testing") val sc = new SparkContext(conf) sqlContext = new SQLContext(sc) - sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) } override def afterAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 066c16e535c76..2ec17146476fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -93,20 +93,6 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) } - def makeShuffledHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - boundCondition: Option[Expression], - leftPlan: SparkPlan, - rightPlan: SparkPlan, - side: BuildSide) = { - val shuffledHashJoin = - execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan) - val filteredJoin = - boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) - EnsureRequirements(sqlContext).apply(filteredJoin) - } - def makeSortMergeJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -143,30 +129,6 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } - test(s"$testName using ShuffledHashJoin (build=left)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => - makeShuffledHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - } - - test(s"$testName using ShuffledHashJoin (build=right)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => - makeShuffledHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - } - test(s"$testName using SortMergeJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 09e0237a7cc50..9c80714a9af4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -74,18 +74,6 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { ExtractEquiJoinKeys.unapply(join) } - test(s"$testName using ShuffledHashOuterJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext).apply( - ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - } - if (joinType != FullOuter) { test(s"$testName using BroadcastHashOuterJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index 8c2e78b2a9db7..c30327185e169 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -28,12 +28,9 @@ import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRig class HashJoinNodeSuite extends LocalNodeTest { // Test all combinations of the two dimensions: with/out unsafe and build sides - private val maybeUnsafeAndCodegen = Seq(false, true) private val buildSides = Seq(BuildLeft, BuildRight) - maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => - buildSides.foreach { buildSide => - testJoin(unsafeAndCodegen, buildSide) - } + buildSides.foreach { buildSide => + testJoin(buildSide) } /** @@ -45,18 +42,7 @@ class HashJoinNodeSuite extends LocalNodeTest { buildKeys: Seq[Expression], buildNode: LocalNode): HashedRelation = { - val isUnsafeMode = - conf.codegenEnabled && - conf.unsafeEnabled && - UnsafeProjection.canSupport(buildKeys) - - val buildSideKeyGenerator = - if (isUnsafeMode) { - UnsafeProjection.create(buildKeys, buildNode.output) - } else { - new InterpretedMutableProjection(buildKeys, buildNode.output) - } - + val buildSideKeyGenerator = UnsafeProjection.create(buildKeys, buildNode.output) buildNode.prepare() buildNode.open() val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator) @@ -68,15 +54,10 @@ class HashJoinNodeSuite extends LocalNodeTest { /** * Test inner hash join with varying degrees of matches. */ - private def testJoin( - unsafeAndCodegen: Boolean, - buildSide: BuildSide): Unit = { - val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" - val testNamePrefix = s"$simpleOrUnsafe / $buildSide" + private def testJoin(buildSide: BuildSide): Unit = { + val testNamePrefix = buildSide val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray val conf = new SQLConf - conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) - conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) // Actual test body def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = { @@ -119,7 +100,7 @@ class HashJoinNodeSuite extends LocalNodeTest { .map { case (k, v) => (k, v, k, rightInputMap(k)) } Seq(makeBinaryHashJoinNode, makeBroadcastJoinNode).foreach { makeNode => - val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val makeUnsafeNode = wrapForUnsafe(makeNode) val hashJoinNode = makeUnsafeNode(leftNode, rightNode) val actualOutput = hashJoinNode.collect().map { row => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala index 40299d9d5ee37..252f7cc8971f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -26,30 +26,21 @@ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} class NestedLoopJoinNodeSuite extends LocalNodeTest { // Test all combinations of the three dimensions: with/out unsafe, build sides, and join types - private val maybeUnsafeAndCodegen = Seq(false, true) private val buildSides = Seq(BuildLeft, BuildRight) private val joinTypes = Seq(LeftOuter, RightOuter, FullOuter) - maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => - buildSides.foreach { buildSide => - joinTypes.foreach { joinType => - testJoin(unsafeAndCodegen, buildSide, joinType) - } + buildSides.foreach { buildSide => + joinTypes.foreach { joinType => + testJoin(buildSide, joinType) } } /** * Test outer nested loop joins with varying degrees of matches. */ - private def testJoin( - unsafeAndCodegen: Boolean, - buildSide: BuildSide, - joinType: JoinType): Unit = { - val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" - val testNamePrefix = s"$simpleOrUnsafe / $buildSide / $joinType" + private def testJoin(buildSide: BuildSide, joinType: JoinType): Unit = { + val testNamePrefix = s"$buildSide / $joinType" val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray val conf = new SQLConf - conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) - conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) // Actual test body def runTest( @@ -63,7 +54,7 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest { resolveExpressions( new NestedLoopJoinNode(conf, node1, node2, buildSide, joinType, Some(cond))) } - val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val makeUnsafeNode = wrapForUnsafe(makeNode) val hashJoinNode = makeUnsafeNode(leftNode, rightNode) val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType) val actualOutput = hashJoinNode.collect().map { row => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index cdd885ba14203..5e2b4154dd7ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -21,8 +21,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ -import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ +import org.apache.xbean.asm5._ +import org.apache.xbean.asm5.Opcodes._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ @@ -41,22 +41,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { l += 1L l.add(1L) } - BoxingFinder.getClassReader(f.getClass).foreach { cl => - val boxingFinder = new BoxingFinder() - cl.accept(boxingFinder, 0) - assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") - } + val cl = BoxingFinder.getClassReader(f.getClass) + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") } test("Normal accumulator should do boxing") { // We need this test to make sure BoxingFinder works. val l = sparkContext.accumulator(0L) val f = () => { l += 1L } - BoxingFinder.getClassReader(f.getClass).foreach { cl => - val boxingFinder = new BoxingFinder() - cl.accept(boxingFinder, 0) - assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test") - } + val cl = BoxingFinder.getClassReader(f.getClass) + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test") } /** @@ -112,33 +110,13 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } test("Project metrics") { - withSQLConf( - SQLConf.UNSAFE_ENABLED.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "false", - SQLConf.TUNGSTEN_ENABLED.key -> "false") { - // Assume the execution plan is - // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) - val df = person.select('name) - testSparkPlanMetrics(df, 1, Map( - 0L ->("Project", Map( - "number of rows" -> 2L))) - ) - } - } - - test("TungstenProject metrics") { - withSQLConf( - SQLConf.UNSAFE_ENABLED.key -> "true", - SQLConf.CODEGEN_ENABLED.key -> "true", - SQLConf.TUNGSTEN_ENABLED.key -> "true") { - // Assume the execution plan is - // PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = person.select('name) - testSparkPlanMetrics(df, 1, Map( - 0L ->("TungstenProject", Map( - "number of rows" -> 2L))) - ) - } + // Assume the execution plan is + // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0) + val df = person.select('name) + testSparkPlanMetrics(df, 1, Map( + 0L ->("Project", Map( + "number of rows" -> 2L))) + ) } test("Filter metrics") { @@ -152,288 +130,152 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } - test("Aggregate metrics") { - withSQLConf( - SQLConf.UNSAFE_ENABLED.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "false", - SQLConf.TUNGSTEN_ENABLED.key -> "false") { - // Assume the execution plan is - // ... -> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0) - val df = testData2.groupBy().count() // 2 partitions - testSparkPlanMetrics(df, 1, Map( - 2L -> ("Aggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 2L)), - 0L -> ("Aggregate", Map( - "number of input rows" -> 2L, - "number of output rows" -> 1L))) - ) - - // 2 partitions and each partition contains 2 keys - val df2 = testData2.groupBy('a).count() - testSparkPlanMetrics(df2, 1, Map( - 2L -> ("Aggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 4L)), - 0L -> ("Aggregate", Map( - "number of input rows" -> 4L, - "number of output rows" -> 3L))) - ) - } - } - - test("SortBasedAggregate metrics") { - // Because SortBasedAggregate may skip different rows if the number of partitions is different, - // this test should use the deterministic number of partitions. - withSQLConf( - SQLConf.UNSAFE_ENABLED.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "true", - SQLConf.TUNGSTEN_ENABLED.key -> "true") { - // Assume the execution plan is - // ... -> SortBasedAggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> - // SortBasedAggregate(nodeId = 0) - val df = testData2.groupBy().count() // 2 partitions - testSparkPlanMetrics(df, 1, Map( - 2L -> ("SortBasedAggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 2L)), - 0L -> ("SortBasedAggregate", Map( - "number of input rows" -> 2L, - "number of output rows" -> 1L))) - ) - - // Assume the execution plan is - // ... -> SortBasedAggregate(nodeId = 3) -> TungstenExchange(nodeId = 2) - // -> ExternalSort(nodeId = 1)-> SortBasedAggregate(nodeId = 0) - // 2 partitions and each partition contains 2 keys - val df2 = testData2.groupBy('a).count() - testSparkPlanMetrics(df2, 1, Map( - 3L -> ("SortBasedAggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 4L)), - 0L -> ("SortBasedAggregate", Map( - "number of input rows" -> 4L, - "number of output rows" -> 3L))) - ) - } - } - test("TungstenAggregate metrics") { - withSQLConf( - SQLConf.UNSAFE_ENABLED.key -> "true", - SQLConf.CODEGEN_ENABLED.key -> "true", - SQLConf.TUNGSTEN_ENABLED.key -> "true") { - // Assume the execution plan is - // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) - // -> TungstenAggregate(nodeId = 0) - val df = testData2.groupBy().count() // 2 partitions - testSparkPlanMetrics(df, 1, Map( - 2L -> ("TungstenAggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 2L)), - 0L -> ("TungstenAggregate", Map( - "number of input rows" -> 2L, - "number of output rows" -> 1L))) - ) + // Assume the execution plan is + // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) + // -> TungstenAggregate(nodeId = 0) + val df = testData2.groupBy().count() // 2 partitions + testSparkPlanMetrics(df, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 2L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 2L, + "number of output rows" -> 1L))) + ) - // 2 partitions and each partition contains 2 keys - val df2 = testData2.groupBy('a).count() - testSparkPlanMetrics(df2, 1, Map( - 2L -> ("TungstenAggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 4L)), - 0L -> ("TungstenAggregate", Map( - "number of input rows" -> 4L, - "number of output rows" -> 3L))) - ) - } + // 2 partitions and each partition contains 2 keys + val df2 = testData2.groupBy('a).count() + testSparkPlanMetrics(df2, 1, Map( + 2L -> ("TungstenAggregate", Map( + "number of input rows" -> 6L, + "number of output rows" -> 4L)), + 0L -> ("TungstenAggregate", Map( + "number of input rows" -> 4L, + "number of output rows" -> 3L))) + ) } test("SortMergeJoin metrics") { // Because SortMergeJoin may skip different rows if the number of partitions is different, this // test should use the deterministic number of partitions. - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) - testDataForJoin.registerTempTable("testDataForJoin") - withTempTable("testDataForJoin") { - // Assume the execution plan is - // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( - "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") - testSparkPlanMetrics(df, 1, Map( - 1L -> ("SortMergeJoin", Map( - // It's 4 because we only read 3 rows in the first partition and 1 row in the second one - "number of left rows" -> 4L, - "number of right rows" -> 2L, - "number of output rows" -> 4L))) - ) - } + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df, 1, Map( + 1L -> ("SortMergeJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 4L, + "number of right rows" -> 2L, + "number of output rows" -> 4L))) + ) } } test("SortMergeOuterJoin metrics") { // Because SortMergeOuterJoin may skip different rows if the number of partitions is different, // this test should use the deterministic number of partitions. - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) - testDataForJoin.registerTempTable("testDataForJoin") - withTempTable("testDataForJoin") { - // Assume the execution plan is - // ... -> SortMergeOuterJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( - "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") - testSparkPlanMetrics(df, 1, Map( - 1L -> ("SortMergeOuterJoin", Map( - // It's 4 because we only read 3 rows in the first partition and 1 row in the second one - "number of left rows" -> 6L, - "number of right rows" -> 2L, - "number of output rows" -> 8L))) - ) - - val df2 = sqlContext.sql( - "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") - testSparkPlanMetrics(df2, 1, Map( - 1L -> ("SortMergeOuterJoin", Map( - // It's 4 because we only read 3 rows in the first partition and 1 row in the second one - "number of left rows" -> 2L, - "number of right rows" -> 6L, - "number of output rows" -> 8L))) - ) - } - } - } - - test("BroadcastHashJoin metrics") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { - val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") - val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key", "value") - // Assume the execution plan is - // ... -> BroadcastHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = df1.join(broadcast(df2), "key") - testSparkPlanMetrics(df, 2, Map( - 1L -> ("BroadcastHashJoin", Map( - "number of left rows" -> 2L, - "number of right rows" -> 4L, - "number of output rows" -> 2L))) - ) - } - } - - test("ShuffledHashJoin metrics") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { - val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) - testDataForJoin.registerTempTable("testDataForJoin") - withTempTable("testDataForJoin") { - // Assume the execution plan is - // ... -> ShuffledHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( - "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") - testSparkPlanMetrics(df, 1, Map( - 1L -> ("ShuffledHashJoin", Map( - "number of left rows" -> 6L, - "number of right rows" -> 2L, - "number of output rows" -> 4L))) - ) - } - } - } - - test("ShuffledHashOuterJoin metrics") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { - val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") - val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { // Assume the execution plan is - // ... -> ShuffledHashOuterJoin(nodeId = 0) - val df = df1.join(df2, $"key" === $"key2", "left_outer") + // ... -> SortMergeOuterJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df, 1, Map( - 0L -> ("ShuffledHashOuterJoin", Map( - "number of left rows" -> 3L, - "number of right rows" -> 4L, - "number of output rows" -> 5L))) - ) - - val df3 = df1.join(df2, $"key" === $"key2", "right_outer") - testSparkPlanMetrics(df3, 1, Map( - 0L -> ("ShuffledHashOuterJoin", Map( - "number of left rows" -> 3L, - "number of right rows" -> 4L, - "number of output rows" -> 6L))) + 1L -> ("SortMergeOuterJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 6L, + "number of right rows" -> 2L, + "number of output rows" -> 8L))) ) - val df4 = df1.join(df2, $"key" === $"key2", "outer") - testSparkPlanMetrics(df4, 1, Map( - 0L -> ("ShuffledHashOuterJoin", Map( - "number of left rows" -> 3L, - "number of right rows" -> 4L, - "number of output rows" -> 7L))) + val df2 = sqlContext.sql( + "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") + testSparkPlanMetrics(df2, 1, Map( + 1L -> ("SortMergeOuterJoin", Map( + // It's 4 because we only read 3 rows in the first partition and 1 row in the second one + "number of left rows" -> 2L, + "number of right rows" -> 6L, + "number of output rows" -> 8L))) ) } } + test("BroadcastHashJoin metrics") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key", "value") + // Assume the execution plan is + // ... -> BroadcastHashJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = df1.join(broadcast(df2), "key") + testSparkPlanMetrics(df, 2, Map( + 1L -> ("BroadcastHashJoin", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) + } + test("BroadcastHashOuterJoin metrics") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { - val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") - val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") - // Assume the execution plan is - // ... -> BroadcastHashOuterJoin(nodeId = 0) - val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") - testSparkPlanMetrics(df, 2, Map( - 0L -> ("BroadcastHashOuterJoin", Map( - "number of left rows" -> 3L, - "number of right rows" -> 4L, - "number of output rows" -> 5L))) - ) + val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") + val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") + // Assume the execution plan is + // ... -> BroadcastHashOuterJoin(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 5L))) + ) - val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") - testSparkPlanMetrics(df3, 2, Map( - 0L -> ("BroadcastHashOuterJoin", Map( - "number of left rows" -> 3L, - "number of right rows" -> 4L, - "number of output rows" -> 6L))) - ) - } + val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") + testSparkPlanMetrics(df3, 2, Map( + 0L -> ("BroadcastHashOuterJoin", Map( + "number of left rows" -> 3L, + "number of right rows" -> 4L, + "number of output rows" -> 6L))) + ) } test("BroadcastNestedLoopJoin metrics") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { - val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) - testDataForJoin.registerTempTable("testDataForJoin") - withTempTable("testDataForJoin") { - // Assume the execution plan is - // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( - "SELECT * FROM testData2 left JOIN testDataForJoin ON " + - "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") - testSparkPlanMetrics(df, 3, Map( - 1L -> ("BroadcastNestedLoopJoin", Map( - "number of left rows" -> 12L, // left needs to be scanned twice - "number of right rows" -> 2L, - "number of output rows" -> 12L))) - ) - } + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.registerTempTable("testDataForJoin") + withTempTable("testDataForJoin") { + // Assume the execution plan is + // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = sqlContext.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON " + + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") + testSparkPlanMetrics(df, 3, Map( + 1L -> ("BroadcastNestedLoopJoin", Map( + "number of left rows" -> 12L, // left needs to be scanned twice + "number of right rows" -> 2L, + "number of output rows" -> 12L))) + ) } } test("BroadcastLeftSemiJoinHash metrics") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { - val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") - val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") - // Assume the execution plan is - // ... -> BroadcastLeftSemiJoinHash(nodeId = 0) - val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") - testSparkPlanMetrics(df, 2, Map( - 0L -> ("BroadcastLeftSemiJoinHash", Map( - "number of left rows" -> 2L, - "number of right rows" -> 4L, - "number of output rows" -> 2L))) - ) - } + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> BroadcastLeftSemiJoinHash(nodeId = 0) + val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("BroadcastLeftSemiJoinHash", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) } test("LeftSemiJoinHash metrics") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") // Assume the execution plan is @@ -449,19 +291,17 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } test("LeftSemiJoinBNL metrics") { - withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") { - val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") - val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") - // Assume the execution plan is - // ... -> LeftSemiJoinBNL(nodeId = 0) - val df = df1.join(df2, $"key" < $"key2", "leftsemi") - testSparkPlanMetrics(df, 2, Map( - 0L -> ("LeftSemiJoinBNL", Map( - "number of left rows" -> 2L, - "number of right rows" -> 4L, - "number of output rows" -> 2L))) - ) - } + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value") + // Assume the execution plan is + // ... -> LeftSemiJoinBNL(nodeId = 0) + val df = df1.join(df2, $"key" < $"key2", "leftsemi") + testSparkPlanMetrics(df, 2, Map( + 0L -> ("LeftSemiJoinBNL", Map( + "number of left rows" -> 2L, + "number of right rows" -> 4L, + "number of output rows" -> 2L))) + ) } test("CartesianProduct metrics") { @@ -516,7 +356,7 @@ private class BoxingFinder( method: MethodIdentifier[_] = null, val boxingInvokes: mutable.Set[String] = mutable.Set.empty, visitedMethods: mutable.Set[MethodIdentifier[_]] = mutable.Set.empty) - extends ClassVisitor(ASM4) { + extends ClassVisitor(ASM5) { private val primitiveBoxingClassName = Set("java/lang/Long", @@ -533,11 +373,12 @@ private class BoxingFinder( MethodVisitor = { if (method != null && (method.name != name || method.desc != desc)) { // If method is specified, skip other methods. - return new MethodVisitor(ASM4) {} + return new MethodVisitor(ASM5) {} } - new MethodVisitor(ASM4) { - override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + new MethodVisitor(ASM5) { + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean) { if (op == INVOKESPECIAL && name == "" || op == INVOKESTATIC && name == "valueOf") { if (primitiveBoxingClassName.contains(owner)) { // Find boxing methods, e.g, new java.lang.Long(l) or java.lang.Long.valueOf(l) @@ -552,10 +393,9 @@ private class BoxingFinder( if (!visitedMethods.contains(m)) { // Keep track of visited methods to avoid potential infinite cycles visitedMethods += m - BoxingFinder.getClassReader(classOfMethodOwner).foreach { cl => - visitedMethods += m - cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0) - } + val cl = BoxingFinder.getClassReader(classOfMethodOwner) + visitedMethods += m + cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0) } } } @@ -565,22 +405,14 @@ private class BoxingFinder( private object BoxingFinder { - def getClassReader(cls: Class[_]): Option[ClassReader] = { + def getClassReader(cls: Class[_]): ClassReader = { val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" val resourceStream = cls.getResourceAsStream(className) val baos = new ByteArrayOutputStream(128) // Copy data over, before delegating to ClassReader - // else we can run out of open file handles. Utils.copyStream(resourceStream, baos, true) - // ASM4 doesn't support Java 8 classes, which requires ASM5. - // So if the class is ASM5 (E.g., java.lang.Long when using JDK8 runtime to run these codes), - // then ClassReader will throw IllegalArgumentException, - // However, since this is only for testing, it's safe to skip these classes. - try { - Some(new ClassReader(new ByteArrayInputStream(baos.toByteArray))) - } catch { - case _: IllegalArgumentException => None - } + new ClassReader(new ByteArrayInputStream(baos.toByteArray)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 2cad964e55b2b..398b8a1a661c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -254,7 +254,11 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3, Set("a", "b", "c")) testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0, Set("a", "b", "c")) - testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10, Set("a", "b", "c")) + testPushDown( + "SELECT * FROM oneToTenFiltered WHERE b = 1", + 10, + Set("a", "b", "c"), + Set(EqualTo("b", 1))) testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", 3, Set("a", "b", "c")) testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4, Set("a", "b", "c")) @@ -283,12 +287,23 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic | WHERE a + b > 9 | AND b < 16 | AND c IN ('bbbbbBBBBB', 'cccccCCCCC', 'dddddDDDDD', 'foo') - """.stripMargin.split("\n").map(_.trim).mkString(" "), 3, Set("a", "b")) + """.stripMargin.split("\n").map(_.trim).mkString(" "), + 3, + Set("a", "b"), + Set(LessThan("b", 16))) def testPushDown( - sqlString: String, - expectedCount: Int, - requiredColumnNames: Set[String]): Unit = { + sqlString: String, + expectedCount: Int, + requiredColumnNames: Set[String]): Unit = { + testPushDown(sqlString, expectedCount, requiredColumnNames, Set.empty[Filter]) + } + + def testPushDown( + sqlString: String, + expectedCount: Int, + requiredColumnNames: Set[String], + expectedUnhandledFilters: Set[Filter]): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { val queryExecution = sql(sqlString).queryExecution val rawPlan = queryExecution.executedPlan.collect { @@ -300,15 +315,13 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic val rawCount = rawPlan.execute().count() assert(ColumnsRequired.set === requiredColumnNames) - assert { - val table = caseInsensitiveContext.table("oneToTenFiltered") - val relation = table.queryExecution.logical.collectFirst { - case LogicalRelation(r, _) => r - }.get + val table = caseInsensitiveContext.table("oneToTenFiltered") + val relation = table.queryExecution.logical.collectFirst { + case LogicalRelation(r, _) => r + }.get - // `relation` should be able to handle all pushed filters - relation.unhandledFilters(FiltersPushed.list.toArray).isEmpty - } + assert( + relation.unhandledFilters(FiltersPushed.list.toArray).toSet === expectedUnhandledFilters) if (rawCount != expectedCount) { fail( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index c9791879ec74c..3eaa817f9c0b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -53,4 +53,12 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { Utils.deleteRecursively(path) } + + test("partitioned columns should appear at the end of schema") { + withTempPath { f => + val path = f.getAbsolutePath + Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path) + assert(sqlContext.read.parquet(path).schema.map(_.name) == Seq("j", "i")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 520dea7f7dd92..abad0d7eaaedf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -242,6 +242,17 @@ private[sql] trait SQLTestData { self => df } + protected lazy val courseSales: DataFrame = { + val df = sqlContext.sparkContext.parallelize( + CourseSales("dotNET", 2012, 10000) :: + CourseSales("Java", 2012, 20000) :: + CourseSales("dotNET", 2012, 5000) :: + CourseSales("dotNET", 2013, 48000) :: + CourseSales("Java", 2013, 30000) :: Nil).toDF() + df.registerTempTable("courseSales") + df + } + /** * Initialize all test data such that all temp tables are properly registered. */ @@ -295,4 +306,5 @@ private[sql] object SQLTestData { case class Person(id: Int, name: String, age: Int) case class Salary(personId: Int, salary: Double) case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) + case class CourseSales(course: String, year: Int, earnings: Double) } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 3fa5c8528b602..fcf039916913a 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -27,7 +27,7 @@ import scala.concurrent.{Await, Promise} import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.BeforeAndAfter +import org.scalatest.BeforeAndAfterAll import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkFunSuite} @@ -36,21 +36,26 @@ import org.apache.spark.{Logging, SparkFunSuite} * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary * Hive metastore and warehouse. */ -class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { +class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { val warehousePath = Utils.createTempDir() val metastorePath = Utils.createTempDir() val scratchDirPath = Utils.createTempDir() - before { + override def beforeAll(): Unit = { + super.beforeAll() warehousePath.delete() metastorePath.delete() scratchDirPath.delete() } - after { - warehousePath.delete() - metastorePath.delete() - scratchDirPath.delete() + override def afterAll(): Unit = { + try { + warehousePath.delete() + metastorePath.delete() + scratchDirPath.delete() + } finally { + super.afterAll() + } } /** @@ -79,6 +84,8 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val jdbcUrl = s"jdbc:derby:;databaseName=$metastorePath;create=true" s"""$cliScript | --master local + | --driver-java-options -Dderby.system.durability=test + | --conf spark.ui.enabled=false | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath | --hiveconf ${ConfVars.SCRATCHDIR}=$scratchDirPath diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index ff8ca0150649d..eb1895f263d70 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -41,6 +41,7 @@ import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkFunSuite} @@ -462,6 +463,51 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { assert(conf.get("spark.sql.hive.version") === Some("1.2.1")) } } + + test("SPARK-11595 ADD JAR with input path having URL scheme") { + withJdbcStatement { statement => + val jarPath = "../hive/src/test/resources/TestUDTF.jar" + val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" + + Seq( + s"ADD JAR $jarURL", + s"""CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin + ).foreach(statement.execute) + + val rs1 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2") + + assert(rs1.next()) + assert(rs1.getString(1) === "Function: udtf_count2") + + assert(rs1.next()) + assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") { + rs1.getString(1) + } + + assert(rs1.next()) + assert(rs1.getString(1) === "Usage: To be added.") + + val dataPath = "../hive/src/test/resources/data/files/kv1.txt" + + Seq( + s"CREATE TABLE test_udtf(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '$dataPath' OVERWRITE INTO TABLE test_udtf" + ).foreach(statement.execute) + + val rs2 = statement.executeQuery( + "SELECT key, cc FROM test_udtf LATERAL VIEW udtf_count2(value) dd AS cc") + + assert(rs2.next()) + assert(rs2.getInt(1) === 97) + assert(rs2.getInt(2) === 500) + + assert(rs2.next()) + assert(rs2.getInt(1) === 97) + assert(rs2.getInt(2) === 500) + } + } } class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2d72b959af134..0c473799cc991 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -454,7 +454,15 @@ class HiveContext private[hive]( // Note that HiveUDFs will be overridden by functions registered in this context. @transient override protected[sql] lazy val functionRegistry: FunctionRegistry = - new HiveFunctionRegistry(FunctionRegistry.builtin.copy()) + new HiveFunctionRegistry(FunctionRegistry.builtin.copy(), this) { + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + // Hive Registry need current database to lookup function + // TODO: the current database of executionHive should be consistent with metadataHive + executionHive.withHiveState { + super.lookupFunction(name, children) + } + } + } // The Hive UDF current_database() is foldable, will be evaluated by optimizer, but the optimizer // can't access the SessionState of metadataHive. @@ -576,7 +584,6 @@ class HiveContext private[hive]( HiveTableScans, DataSinks, Scripts, - HashAggregation, Aggregation, LeftSemiJoin, EquiJoinSelection, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index ab88c1e68fd72..091caab921fe9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -38,6 +38,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.{AnalysisException, catalyst} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.{logical, _} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin @@ -1508,9 +1509,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name))) /* Aggregate Functions */ - case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1)) - case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr)) - case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg)) + case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => + Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true) + case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => + Count(Literal(1)).toAggregateExpression() /* Casts */ case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => @@ -1819,6 +1821,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val explode = "(?i)explode".r + val jsonTuple = "(?i)json_tuple".r def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = { val function = nodes.head @@ -1831,6 +1834,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) => (Explode(nodeToExpr(child)), attributes) + case Token("TOK_FUNCTION", Token(jsonTuple(), Nil) :: children) => + (JsonTuple(children.map(nodeToExpr)), attributes) + case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 3dce86c480747..598ccdeee4ad2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.client import java.io.{File, PrintStream} import java.util.{Map => JMap} -import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.language.reflectiveCalls @@ -33,9 +32,10 @@ import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.ql.{Driver, metadata} import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader} +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.VersionInfo -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, SparkException, Logging} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.util.{CircularBuffer, Utils} @@ -150,6 +150,27 @@ private[hive] class ClientWrapper( val original = Thread.currentThread().getContextClassLoader // Switch to the initClassLoader. Thread.currentThread().setContextClassLoader(initClassLoader) + + // Set up kerberos credentials for UserGroupInformation.loginUser within + // current class loader + // Instead of using the spark conf of the current spark context, a new + // instance of SparkConf is needed for the original value of spark.yarn.keytab + // and spark.yarn.principal set in SparkSubmit, as yarn.Client resets the + // keytab configuration for the link name in distributed cache + val sparkConf = new SparkConf + if (sparkConf.contains("spark.yarn.principal") && sparkConf.contains("spark.yarn.keytab")) { + val principalName = sparkConf.get("spark.yarn.principal") + val keytabFileName = sparkConf.get("spark.yarn.keytab") + if (!new File(keytabFileName).exists()) { + throw new SparkException(s"Keytab file: ${keytabFileName}" + + " specified in spark.yarn.keytab does not exist") + } else { + logInfo("Attempting to login to Kerberos" + + s" using principal: ${principalName} and keytab: ${keytabFileName}") + UserGroupInformation.loginUserFromKeytab(principalName, keytabFileName) + } + } + val ret = try { val initialConf = new HiveConf(classOf[SessionState]) // HiveConf is a Hadoop Configuration, which has a field of classLoader and @@ -548,7 +569,15 @@ private[hive] class ClientWrapper( } def addJar(path: String): Unit = { - clientLoader.addJar(path) + val uri = new Path(path).toUri + val jarURL = if (uri.getScheme == null) { + // `path` is a local file path without a URL scheme + new File(path).toURI.toURL + } else { + // `path` is a URL with a scheme + uri.toURL + } + clientLoader.addJar(jarURL) runSqlHive(s"ADD JAR $path") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index f99c3ed2ae987..e041e0d8e5ae8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -22,7 +22,6 @@ import java.lang.reflect.InvocationTargetException import java.net.{URL, URLClassLoader} import java.util -import scala.collection.mutable import scala.language.reflectiveCalls import scala.util.Try @@ -30,10 +29,9 @@ import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.spark.Logging import org.apache.spark.deploy.SparkSubmitUtils -import org.apache.spark.util.{MutableURLClassLoader, Utils} - import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.util.{MutableURLClassLoader, Utils} /** Factory for `IsolatedClientLoader` with specific versions of hive. */ private[hive] object IsolatedClientLoader { @@ -190,9 +188,8 @@ private[hive] class IsolatedClientLoader( new NonClosableMutableURLClassLoader(isolatedClassLoader) } - private[hive] def addJar(path: String): Unit = synchronized { - val jarURL = new java.io.File(path).toURI.toURL - classLoader.addURL(jarURL) + private[hive] def addJar(path: URL): Unit = synchronized { + classLoader.addURL(path) } /** The isolated client interface to Hive. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index a9db70119d011..e6fe2ad5f23b6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -46,17 +46,23 @@ import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types._ -private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) +private[hive] class HiveFunctionRegistry( + underlying: analysis.FunctionRegistry, + hiveContext: HiveContext) extends analysis.FunctionRegistry with HiveInspectors { - def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name) + def getFunctionInfo(name: String): FunctionInfo = { + hiveContext.executionHive.withHiveState { + FunctionRegistry.getFunctionInfo(name) + } + } override def lookupFunction(name: String, children: Seq[Expression]): Expression = { Try(underlying.lookupFunction(name, children)).getOrElse { // We only look it up to see if it exists, but do not include it in the HiveUDF since it is // not always serializable. val functionInfo: FunctionInfo = - Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( + Option(getFunctionInfo(name.toLowerCase)).getOrElse( throw new AnalysisException(s"undefined function $name")) val functionClassName = functionInfo.getFunctionClass.getName @@ -110,7 +116,7 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) override def lookupFunction(name: String): Option[ExpressionInfo] = { underlying.lookupFunction(name).orElse( Try { - val info = FunctionRegistry.getFunctionInfo(name) + val info = getFunctionInfo(name) val annotation = info.getFunctionClass.getAnnotation(classOf[Description]) if (annotation != null) { Some(new ExpressionInfo( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 45de567039760..1136670b7a0eb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -157,7 +157,7 @@ private[sql] class OrcRelation( override val userDefinedPartitionColumns: Option[StructType], parameters: Map[String, String])( @transient val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec) + extends HadoopFsRelation(maybePartitionSpec, parameters) with Logging { private[sql] def this( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 2e5cae415e54b..9864acf765265 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.scalatest.BeforeAndAfterAll @@ -32,7 +32,7 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with private var testData: DataFrame = _ override def beforeAll() { - testData = Seq((1, 2), (2, 4)).toDF("a", "b") + testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") hiveContext.registerDataFrameAsTable(testData, "mytable") } @@ -52,6 +52,17 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with ) } + test("collect functions") { + checkAnswer( + testData.select(collect_list($"a"), collect_list($"b")), + Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4))) + ) + checkAnswer( + testData.select(collect_set($"a"), collect_set($"b")), + Seq(Row(Seq(1, 2, 3), Seq(2, 4))) + ) + } + test("cube") { checkAnswer( testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 528a7398b10df..a330362b4e1d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.spark.sql.catalyst.expressions.JsonTuple +import org.apache.spark.sql.catalyst.plans.logical.Generate import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite @@ -183,4 +185,15 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { assertError("select interval '.1111111111' second", "nanosecond 1111111111 outside range") } + + test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") { + val plan = HiveQl.parseSql( + """ + |SELECT * + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b + """.stripMargin) + + assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple]) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 10e4ae2c50308..24a3afee148c5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -23,7 +23,7 @@ import java.util.Date import scala.collection.mutable.ArrayBuffer -import org.scalatest.Matchers +import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ @@ -42,14 +42,14 @@ import org.apache.spark.util.{ResetSystemProperties, Utils} class HiveSparkSubmitSuite extends SparkFunSuite with Matchers - // This test suite sometimes gets extremely slow out of unknown reason on Jenkins. Here we - // add a timestamp to provide more diagnosis information. + with BeforeAndAfterEach with ResetSystemProperties with Timeouts { // TODO: rewrite these or mark them as slow tests to be run sparingly - def beforeAll() { + override def beforeEach() { + super.beforeEach() System.setProperty("spark.testing", "true") } @@ -66,6 +66,7 @@ class HiveSparkSubmitSuite "--master", "local-cluster[2,1,1024]", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -79,6 +80,7 @@ class HiveSparkSubmitSuite "--master", "local-cluster[2,1,1024]", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", unusedJar.toString) runSparkSubmit(args) } @@ -93,6 +95,7 @@ class HiveSparkSubmitSuite val args = Seq( "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", "--class", "Main", testJar) runSparkSubmit(args) @@ -104,6 +107,9 @@ class HiveSparkSubmitSuite "--class", SPARK_9757.getClass.getName.stripSuffix("$"), "--name", "SparkSQLConfTest", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", unusedJar.toString) runSparkSubmit(args) } @@ -114,6 +120,9 @@ class HiveSparkSubmitSuite "--class", SPARK_11009.getClass.getName.stripSuffix("$"), "--name", "SparkSQLConfTest", "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", unusedJar.toString) runSparkSubmit(args) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 9bb32f11b76bd..f775f1e955876 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -166,7 +166,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { val shj = df.queryExecution.sparkPlan.collect { case j: SortMergeJoin => j } assert(shj.size === 1, - "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") + "SortMergeJoin should be planned when BroadcastHashJoin is turned off") sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp""") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index ea80060e370e0..6dde79f74d3d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -66,14 +66,40 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun } } +class LongProductSum extends UserDefinedAggregateFunction { + def inputSchema: StructType = new StructType() + .add("a", LongType) + .add("b", LongType) + + def bufferSchema: StructType = new StructType() + .add("product", LongType) + + def dataType: DataType = LongType + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + buffer(0) = 0L + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!(input.isNullAt(0) || input.isNullAt(1))) { + buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1) + } + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) + } + + def evaluate(buffer: Row): Any = + buffer.getLong(0) +} + abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ - var originalUseAggregate2: Boolean = _ - override def beforeAll(): Unit = { - originalUseAggregate2 = sqlContext.conf.useSqlAggregate2 - sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true") val data1 = Seq[(Integer, Integer)]( (1, 10), (null, -60), @@ -106,6 +132,22 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te (3, null, null)).toDF("key", "value1", "value2") data2.write.saveAsTable("agg2") + val data3 = Seq[(Seq[Integer], Integer, Integer)]( + (Seq[Integer](1, 1), 10, -10), + (Seq[Integer](null), -60, 60), + (Seq[Integer](1, 1), 30, -30), + (Seq[Integer](1), 30, 30), + (Seq[Integer](2), 1, 1), + (null, -10, 10), + (Seq[Integer](2, 3), -1, null), + (Seq[Integer](2, 3), 1, 1), + (Seq[Integer](2, 3, 4), null, 1), + (Seq[Integer](null), 100, -10), + (Seq[Integer](3), null, 3), + (null, null, null), + (Seq[Integer](3), null, null)).toDF("key", "value1", "value2") + data3.write.saveAsTable("agg3") + val emptyDF = sqlContext.createDataFrame( sparkContext.emptyRDD[Row], StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil)) @@ -114,13 +156,14 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te // Register UDAFs sqlContext.udf.register("mydoublesum", new MyDoubleSum) sqlContext.udf.register("mydoubleavg", new MyDoubleAvg) + sqlContext.udf.register("longProductSum", new LongProductSum) } override def afterAll(): Unit = { sqlContext.sql("DROP TABLE IF EXISTS agg1") sqlContext.sql("DROP TABLE IF EXISTS agg2") + sqlContext.sql("DROP TABLE IF EXISTS agg3") sqlContext.dropTempTable("emptyTable") - sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) } test("empty table") { @@ -240,6 +283,41 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(100, null) :: Row(null, 3) :: Row(null, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT DISTINCT key + |FROM agg3 + """.stripMargin), + Row(Seq[Integer](1, 1)) :: + Row(Seq[Integer](null)) :: + Row(Seq[Integer](1)) :: + Row(Seq[Integer](2)) :: + Row(null) :: + Row(Seq[Integer](2, 3)) :: + Row(Seq[Integer](2, 3, 4)) :: + Row(Seq[Integer](3)) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT value1, key + |FROM agg3 + |GROUP BY value1, key + """.stripMargin), + Row(10, Seq[Integer](1, 1)) :: + Row(-60, Seq[Integer](null)) :: + Row(30, Seq[Integer](1, 1)) :: + Row(30, Seq[Integer](1)) :: + Row(1, Seq[Integer](2)) :: + Row(-10, null) :: + Row(-1, Seq[Integer](2, 3)) :: + Row(1, Seq[Integer](2, 3)) :: + Row(null, Seq[Integer](2, 3, 4)) :: + Row(100, Seq[Integer](null)) :: + Row(null, Seq[Integer](3)) :: + Row(null, null) :: Nil) } test("case in-sensitive resolution") { @@ -447,73 +525,124 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te } test("single distinct column set") { - // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. - checkAnswer( - sqlContext.sql( - """ - |SELECT - | min(distinct value1), - | sum(distinct value1), - | avg(value1), - | avg(value2), - | max(distinct value1) - |FROM agg2 - """.stripMargin), - Row(-60, 70.0, 101.0/9.0, 5.6, 100)) + Seq(true, false).foreach { specializeSingleDistinctAgg => + val conf = + (SQLConf.SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING.key, + specializeSingleDistinctAgg.toString) + withSQLConf(conf) { + // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. + checkAnswer( + sqlContext.sql( + """ + |SELECT + | min(distinct value1), + | sum(distinct value1), + | avg(value1), + | avg(value2), + | max(distinct value1) + |FROM agg2 + """.stripMargin), + Row(-60, 70.0, 101.0/9.0, 5.6, 100)) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | mydoubleavg(distinct value1), + | avg(value1), + | avg(value2), + | key, + | mydoubleavg(value1 - 1), + | mydoubleavg(distinct value1) * 0.1, + | avg(value1 + value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: + Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: + Row(null, null, 3.0, 3, null, null, null) :: + Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | mydoubleavg(distinct value1), + | mydoublesum(value2), + | mydoublesum(distinct value1), + | mydoubleavg(distinct value1), + | mydoubleavg(value1) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: + Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: + Row(3, null, 3.0, null, null, null) :: + Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value1), + | count(*), + | count(1), + | count(DISTINCT value1), + | key + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(3, 3, 3, 2, 1) :: + Row(3, 4, 4, 2, 2) :: + Row(0, 2, 2, 0, 3) :: + Row(3, 4, 4, 3, null) :: Nil) + } + } + } + test("single distinct multiple columns set") { checkAnswer( sqlContext.sql( """ |SELECT - | mydoubleavg(distinct value1), - | avg(value1), - | avg(value2), | key, - | mydoubleavg(value1 - 1), - | mydoubleavg(distinct value1) * 0.1, - | avg(value1 + value2) + | count(distinct value1, value2) |FROM agg2 |GROUP BY key """.stripMargin), - Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: - Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: - Row(null, null, 3.0, 3, null, null, null) :: - Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) + Row(null, 3) :: + Row(1, 3) :: + Row(2, 1) :: + Row(3, 0) :: Nil) + } + test("multiple distinct multiple columns sets") { checkAnswer( sqlContext.sql( """ |SELECT | key, - | mydoubleavg(distinct value1), - | mydoublesum(value2), - | mydoublesum(distinct value1), - | mydoubleavg(distinct value1), - | mydoubleavg(value1) - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: - Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: - Row(3, null, 3.0, null, null, null) :: - Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) - - checkAnswer( - sqlContext.sql( - """ - |SELECT + | count(distinct value1), + | sum(distinct value1), + | count(distinct value2), + | sum(distinct value2), + | count(distinct value1, value2), + | longProductSum(distinct value1, value2), | count(value1), + | sum(value1), + | count(value2), + | sum(value2), + | longProductSum(value1, value2), | count(*), - | count(1), - | count(DISTINCT value1), - | key + | count(1) |FROM agg2 |GROUP BY key """.stripMargin), - Row(3, 3, 3, 2, 1) :: - Row(3, 4, 4, 2, 2) :: - Row(0, 2, 2, 0, 3) :: - Row(3, 4, 4, 3, null) :: Nil) + Row(null, 3, 30, 3, 60, 3, -4700, 3, 30, 3, 60, -4700, 4, 4) :: + Row(1, 2, 40, 3, -10, 3, -100, 3, 70, 3, -10, -100, 3, 3) :: + Row(2, 2, 0, 1, 1, 1, 1, 3, 1, 3, 3, 2, 4, 4) :: + Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil) } test("test count") { @@ -657,48 +786,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0) assert(math.abs(corr7 - 0.6633880657639323) < 1e-12) - - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - val errorMessage = intercept[SparkException] { - val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c") - val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) - }.getMessage - assert(errorMessage.contains("java.lang.UnsupportedOperationException: " + - "Corr only supports the new AggregateExpression2")) - } - } - - test("test Last implemented based on AggregateExpression1") { - // TODO: Remove this test once we remove AggregateExpression1. - import org.apache.spark.sql.functions._ - val df = Seq((1, 1), (2, 2), (3, 3)).toDF("i", "j").repartition(1) - withSQLConf( - SQLConf.SHUFFLE_PARTITIONS.key -> "1", - SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - - checkAnswer( - df.groupBy("i").agg(last("j")), - df - ) - } - } - - test("error handling") { - withSQLConf("spark.sql.useAggregate2" -> "false") { - val errorMessage = intercept[AnalysisException] { - sqlContext.sql( - """ - |SELECT - | key, - | sum(value + 1.5 * key), - | mydoublesum(value), - | mydoubleavg(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).collect() - }.getMessage - assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) - } } test("no aggregation function (SPARK-11486)") { @@ -760,67 +847,25 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te val df = sqlContext.createDataFrame(rdd, schema) val allColumns = df.schema.fields.map(f => col(f.name)) - val expectedAnaswer = + val expectedAnswer = data .find(r => r.getInt(0) == 50) .getOrElse(fail("A row with id 50 should be the expected answer.")) checkAnswer( df.groupBy().agg(udaf(allColumns: _*)), // udaf returns a Row as the output value. - Row(expectedAnaswer) + Row(expectedAnswer) ) } } } -class SortBasedAggregationQuerySuite extends AggregationQuerySuite { - - var originalUnsafeEnabled: Boolean = _ - - override def beforeAll(): Unit = { - originalUnsafeEnabled = sqlContext.conf.unsafeEnabled - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "false") - super.beforeAll() - } - - override def afterAll(): Unit = { - super.afterAll() - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) - } -} - -class TungstenAggregationQuerySuite extends AggregationQuerySuite { - var originalUnsafeEnabled: Boolean = _ +class TungstenAggregationQuerySuite extends AggregationQuerySuite - override def beforeAll(): Unit = { - originalUnsafeEnabled = sqlContext.conf.unsafeEnabled - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") - super.beforeAll() - } - - override def afterAll(): Unit = { - super.afterAll() - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) - } -} class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { - var originalUnsafeEnabled: Boolean = _ - - override def beforeAll(): Unit = { - originalUnsafeEnabled = sqlContext.conf.unsafeEnabled - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") - super.beforeAll() - } - - override def afterAll(): Unit = { - super.afterAll() - sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) - sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") - } - override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { (0 to 2).foreach { fallbackStartsAt => sqlContext.setConf( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 94162da4eae1a..a7b7ad0093915 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -37,8 +37,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==", - "== Physical Plan ==", - "Code Generation") + "== Physical Plan ==") } test("explain create table command") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index fc72e3c7dc6aa..f0a7a6cc7a1e3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -20,22 +20,19 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} -import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin - import scala.util.Try -import org.scalatest.BeforeAndAfter - import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkFiles, SparkException} -import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.test.TestHiveContext -import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.{SparkException, SparkFiles} case class TestData(a: Int, b: String) @@ -927,7 +924,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("SPARK-2263: Insert Map values") { sql("CREATE TABLE m(value MAP)") sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") - sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).map { + sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).foreach { case (Row(map: Map[_, _]), Row(key: Int, value: String)) => assert(map.size === 1) assert(map.head === (key, value)) @@ -961,10 +958,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("CREATE TEMPORARY FUNCTION") { val funcJar = TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath - sql(s"ADD JAR $funcJar") + val jarURL = s"file://$funcJar" + sql(s"ADD JAR $jarURL") sql( """CREATE TEMPORARY FUNCTION udtf_count2 AS - | 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'""".stripMargin) + |'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) assert(sql("DESCRIBE FUNCTION udtf_count2").count > 1) sql("DROP TEMPORARY FUNCTION udtf_count2") } @@ -1235,6 +1234,26 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } + test("lookup hive UDF in another thread") { + val e = intercept[AnalysisException] { + range(1).selectExpr("not_a_udf()") + } + assert(e.getMessage.contains("undefined function not_a_udf")) + var success = false + val t = new Thread("test") { + override def run(): Unit = { + val e = intercept[AnalysisException] { + range(1).selectExpr("not_a_udf()") + } + assert(e.getMessage.contains("undefined function not_a_udf")) + success = true + } + } + t.start() + t.join() + assert(success) + } + createQueryTest("select from thrift based table", "SELECT * from src_thrift") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index 197e9bfb02c4e..5bd323ea096a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -43,7 +43,9 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { test("[SPARK-2210] boolean cast on boolean value should be removed") { val q = "select cast(cast(key=0 as boolean) as boolean) from src" - val project = TestHive.sql(q).queryExecution.executedPlan.collect { case e: Project => e }.head + val project = TestHive.sql(q).queryExecution.executedPlan.collect { + case e: Project => e + }.head // No cast expression introduced project.transformAllExpressions { case c: Cast => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 5f9a447759b48..9deb1a6db15ad 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import java.io.{DataInput, DataOutput} +import java.io.{PrintWriter, File, DataInput, DataOutput} import java.util.{ArrayList, Arrays, Properties} import org.apache.hadoop.conf.Configuration @@ -28,11 +28,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.io.Writable -import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton - import org.apache.spark.util.Utils + case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) // Case classes for the custom UDF's. @@ -44,7 +45,7 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. */ -class HiveUDFSuite extends QueryTest with TestHiveSingleton { +class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { import hiveContext.{udf, sql} import hiveContext.implicits._ @@ -92,44 +93,36 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton { } test("Max/Min on named_struct") { - def testOrderInStruct(): Unit = { - checkAnswer(sql( - """ - |SELECT max(named_struct( - | "key", key, - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_498"))) - checkAnswer(sql( - """ - |SELECT min(named_struct( - | "key", key, - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_0"))) - - // nested struct cases - checkAnswer(sql( - """ - |SELECT max(named_struct( - | "key", named_struct( - "key", key, - "value", value), - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_498"))) - checkAnswer(sql( - """ - |SELECT min(named_struct( - | "key", named_struct( - "key", key, - "value", value), - | "value", value)).value FROM src - """.stripMargin), Seq(Row("val_0"))) - } - val codegenDefault = hiveContext.getConf(SQLConf.CODEGEN_ENABLED) - hiveContext.setConf(SQLConf.CODEGEN_ENABLED, true) - testOrderInStruct() - hiveContext.setConf(SQLConf.CODEGEN_ENABLED, false) - testOrderInStruct() - hiveContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) + checkAnswer(sql( + """ + |SELECT max(named_struct( + | "key", key, + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_498"))) + checkAnswer(sql( + """ + |SELECT min(named_struct( + | "key", key, + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_0"))) + + // nested struct cases + checkAnswer(sql( + """ + |SELECT max(named_struct( + | "key", named_struct( + "key", key, + "value", value), + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_498"))) + checkAnswer(sql( + """ + |SELECT min(named_struct( + | "key", named_struct( + "key", key, + "value", value), + | "value", value)).value FROM src + """.stripMargin), Seq(Row("val_0"))) } test("SPARK-6409 UDAF Average test") { @@ -356,6 +349,94 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton { sqlContext.dropTempTable("testUDF") } + + test("SPARK-11522 select input_file_name from non-parquet table"){ + + withTempDir { tempDir => + + // EXTERNAL OpenCSVSerde table pointing to LOCATION + + val file1 = new File(tempDir + "/data1") + val writer1 = new PrintWriter(file1) + writer1.write("1,2") + writer1.close() + + val file2 = new File(tempDir + "/data2") + val writer2 = new PrintWriter(file2) + writer2.write("1,2") + writer2.close() + + sql( + s"""CREATE EXTERNAL TABLE csv_table(page_id INT, impressions INT) + ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' + WITH SERDEPROPERTIES ( + \"separatorChar\" = \",\", + \"quoteChar\" = \"\\\"\", + \"escapeChar\" = \"\\\\\") + LOCATION '$tempDir' + """) + + val answer1 = + sql("SELECT input_file_name() FROM csv_table").head().getString(0) + assert(answer1.contains("data1") || answer1.contains("data2")) + + val count1 = sql("SELECT input_file_name() FROM csv_table").distinct().count() + assert(count1 == 2) + sql("DROP TABLE csv_table") + + // EXTERNAL pointing to LOCATION + + sql( + s"""CREATE EXTERNAL TABLE external_t5 (c1 int, c2 int) + ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + LOCATION '$tempDir' + """) + + val answer2 = + sql("SELECT input_file_name() as file FROM external_t5").head().getString(0) + assert(answer1.contains("data1") || answer1.contains("data2")) + + val count2 = sql("SELECT input_file_name() as file FROM external_t5").distinct().count + assert(count2 == 2) + sql("DROP TABLE external_t5") + } + + withTempDir { tempDir => + + // External parquet pointing to LOCATION + + val parquetLocation = tempDir + "/external_parquet" + sql("SELECT 1, 2").write.parquet(parquetLocation) + + sql( + s"""CREATE EXTERNAL TABLE external_parquet(c1 int, c2 int) + STORED AS PARQUET + LOCATION '$parquetLocation' + """) + + val answer3 = + sql("SELECT input_file_name() as file FROM external_parquet").head().getString(0) + assert(answer3.contains("external_parquet")) + + val count3 = sql("SELECT input_file_name() as file FROM external_parquet").distinct().count + assert(count3 == 1) + sql("DROP TABLE external_parquet") + } + + // Non-External parquet pointing to /tmp/... + + sql("CREATE TABLE parquet_tmp(c1 int, c2 int) " + + " STORED AS parquet " + + " AS SELECT 1, 2") + + val answer4 = + sql("SELECT input_file_name() as file FROM parquet_tmp").head().getString(0) + assert(answer4.contains("parquet_tmp")) + + val count4 = sql("SELECT input_file_name() as file FROM parquet_tmp").distinct().count + assert(count4 == 1) + sql("DROP TABLE parquet_tmp") + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index af48d478953b4..3427152b2da02 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1428,4 +1428,55 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sql("SELECT val FROM tbl10562 WHERE Year == 2012"), Row("a")) } } + + test("SPARK-11453: append data to partitioned table") { + withTable("tbl11453") { + Seq("1" -> "10", "2" -> "20").toDF("i", "j") + .write.partitionBy("i").saveAsTable("tbl11453") + + Seq("3" -> "30").toDF("i", "j") + .write.mode(SaveMode.Append).partitionBy("i").saveAsTable("tbl11453") + checkAnswer( + sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"), + Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Nil) + + // make sure case sensitivity is correct. + Seq("4" -> "40").toDF("i", "j") + .write.mode(SaveMode.Append).partitionBy("I").saveAsTable("tbl11453") + checkAnswer( + sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"), + Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Row("4", "40") :: Nil) + } + } + + test("SPARK-11590: use native json_tuple in lateral view") { + checkAnswer(sql( + """ + |SELECT a, b + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b + """.stripMargin), Row("value1", "12")) + + // we should use `c0`, `c1`... as the name of fields if no alias is provided, to follow hive. + checkAnswer(sql( + """ + |SELECT c0, c1 + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt + """.stripMargin), Row("value1", "12")) + + // we can also use `json_tuple` in project list. + checkAnswer(sql( + """ + |SELECT json_tuple(json, 'f1', 'f2') + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + """.stripMargin), Row("value1", "12")) + + // we can also mix `json_tuple` with other project expressions. + checkAnswer(sql( + """ + |SELECT json_tuple(json, 'f1', 'f2'), 3.14, str + |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test + """.stripMargin), Row("value1", "12", 3.14, "hello")) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index e2d754e806403..e866493ee6c96 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -23,7 +23,7 @@ import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{execution, AnalysisException, SaveMode} +import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -155,4 +155,23 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { assert(physicalPlan.collect { case p: execution.Filter => p }.length === 1) } } + + test("SPARK-11500: Not deterministic order of columns when using merging schemas.") { + import testImplicits._ + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") { + withTempPath { dir => + val pathOne = s"${dir.getCanonicalPath}/part=1" + Seq(1, 1).zipWithIndex.toDF("a", "b").write.parquet(pathOne) + val pathTwo = s"${dir.getCanonicalPath}/part=2" + Seq(1, 1).zipWithIndex.toDF("c", "b").write.parquet(pathTwo) + val pathThree = s"${dir.getCanonicalPath}/part=3" + Seq(1, 1).zipWithIndex.toDF("d", "b").write.parquet(pathThree) + + // The schema consists of the leading columns of the first part-file + // in the lexicographic order. + assert(sqlContext.read.parquet(dir.getCanonicalPath).schema.map(_.name) + === Seq("a", "b", "c", "d", "part")) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 9251a69f31a47..81af684ba0bf1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -248,7 +248,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat projections = Seq('c, 'p), filter = 'a < 3 && 'p > 0, requiredColumns = Seq("c", "a"), - pushedFilters = Nil, + pushedFilters = Seq(LessThan("a", 3)), inconvertibleFilters = Nil, unhandledFilters = Seq('a < 3), partitioningFilters = Seq('p > 0) @@ -327,7 +327,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat projections = Seq('b, 'p), filter = 'c > "val_7" && 'b < 18 && 'p > 0, requiredColumns = Seq("b"), - pushedFilters = Seq(GreaterThan("c", "val_7")), + pushedFilters = Seq(GreaterThan("c", "val_7"), LessThan("b", 18)), inconvertibleFilters = Nil, unhandledFilters = Seq('b < 18), partitioningFilters = Seq('p > 0) @@ -344,7 +344,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat projections = Seq('b, 'p), filter = 'a % 2 === 0 && 'c > "val_7" && 'b < 18 && 'p > 0, requiredColumns = Seq("b", "a"), - pushedFilters = Seq(GreaterThan("c", "val_7")), + pushedFilters = Seq(GreaterThan("c", "val_7"), LessThan("b", 18)), inconvertibleFilters = Seq('a % 2 === 0), unhandledFilters = Seq('b < 18), partitioningFilters = Seq('p > 0) @@ -361,7 +361,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat projections = Seq('b, 'p), filter = 'a > 7 && 'a < 9, requiredColumns = Seq("b", "a"), - pushedFilters = Seq(GreaterThan("a", 7)), + pushedFilters = Seq(GreaterThan("a", 7), LessThan("a", 9)), inconvertibleFilters = Nil, unhandledFilters = Seq('a < 9), partitioningFilters = Nil diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index bdc48a383bbbf..01960fd2901b0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -89,7 +89,7 @@ class SimpleTextRelation( override val userDefinedPartitionColumns: Option[StructType], parameters: Map[String, String])( @transient val sqlContext: SQLContext) - extends HadoopFsRelation { + extends HadoopFsRelation(parameters) { import sqlContext.sparkContext diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 100b97137cff0..665e87e3e3355 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -486,6 +486,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes val df = sqlContext.read .format(dataSourceName) .option("dataSchema", dataSchema.json) + .option("basePath", file.getCanonicalPath) .load(s"${file.getCanonicalPath}/p1=*/p2=???") val expectedPaths = Set( diff --git a/streaming/pom.xml b/streaming/pom.xml index 145c8a7321c05..435e16db13ab4 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -93,6 +93,11 @@ selenium-java test
    + + org.mockito + mockito-core + test + target/scala-${scala.binary.version}/classes diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala new file mode 100644 index 0000000000000..604e64fc61630 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import scala.language.implicitConversions + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of + * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * + * Scala example of using `State`: + * {{{ + * // A tracking function that maintains an integer state and return a String + * def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = { + * // Check if state exists + * if (state.exists) { + * val existingState = state.get // Get the existing state + * val shouldRemove = ... // Decide whether to remove the state + * if (shouldRemove) { + * state.remove() // Remove the state + * } else { + * val newState = ... + * state.update(newState) // Set the new state + * } + * } else { + * val initialState = ... + * state.update(initialState) // Set the initial state + * } + * ... // return something + * } + * + * }}} + * + * Java example of using `State`: + * {{{ + * // A tracking function that maintains an integer state and return a String + * Function2, State, Optional> trackStateFunc = + * new Function2, State, Optional>() { + * + * @Override + * public Optional call(Optional one, State state) { + * if (state.exists()) { + * int existingState = state.get(); // Get the existing state + * boolean shouldRemove = ...; // Decide whether to remove the state + * if (shouldRemove) { + * state.remove(); // Remove the state + * } else { + * int newState = ...; + * state.update(newState); // Set the new state + * } + * } else { + * int initialState = ...; // Set the initial state + * state.update(initialState); + * } + * // return something + * } + * }; + * }}} + */ +@Experimental +sealed abstract class State[S] { + + /** Whether the state already exists */ + def exists(): Boolean + + /** + * Get the state if it exists, otherwise it will throw `java.util.NoSuchElementException`. + * Check with `exists()` whether the state exists or not before calling `get()`. + * + * @throws java.util.NoSuchElementException If the state does not exist. + */ + def get(): S + + /** + * Update the state with a new value. + * + * State cannot be updated if it has been already removed (that is, `remove()` has already been + * called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`). + * + * @throws java.lang.IllegalArgumentException If the state has already been removed, or is + * going to be removed + */ + def update(newState: S): Unit + + /** + * Remove the state if it exists. + * + * State cannot be updated if it has been already removed (that is, `remove()` has already been + * called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`). + */ + def remove(): Unit + + /** + * Whether the state is timing out and going to be removed by the system after the current batch. + * This timeout can occur if timeout duration has been specified in the + * [[org.apache.spark.streaming.StateSpec StatSpec]] and the key has not received any new data + * for that timeout duration. + */ + def isTimingOut(): Boolean + + /** + * Get the state as an [[scala.Option]]. It will be `Some(state)` if it exists, otherwise `None`. + */ + @inline final def getOption(): Option[S] = if (exists) Some(get()) else None + + @inline final override def toString(): String = { + getOption.map { _.toString }.getOrElse("") + } +} + +/** Internal implementation of the [[State]] interface */ +private[streaming] class StateImpl[S] extends State[S] { + + private var state: S = null.asInstanceOf[S] + private var defined: Boolean = false + private var timingOut: Boolean = false + private var updated: Boolean = false + private var removed: Boolean = false + + // ========= Public API ========= + override def exists(): Boolean = { + defined + } + + override def get(): S = { + if (defined) { + state + } else { + throw new NoSuchElementException("State is not set") + } + } + + override def update(newState: S): Unit = { + require(!removed, "Cannot update the state after it has been removed") + require(!timingOut, "Cannot update the state that is timing out") + state = newState + defined = true + updated = true + } + + override def isTimingOut(): Boolean = { + timingOut + } + + override def remove(): Unit = { + require(!timingOut, "Cannot remove the state that is timing out") + require(!removed, "Cannot remove the state that has already been removed") + defined = false + updated = false + removed = true + } + + // ========= Internal API ========= + + /** Whether the state has been marked for removing */ + def isRemoved(): Boolean = { + removed + } + + /** Whether the state has been been updated */ + def isUpdated(): Boolean = { + updated + } + + /** + * Update the internal data and flags in `this` to the given state option. + * This method allows `this` object to be reused across many state records. + */ + def wrap(optionalState: Option[S]): Unit = { + optionalState match { + case Some(newState) => + this.state = newState + defined = true + + case None => + this.state = null.asInstanceOf[S] + defined = false + } + timingOut = false + removed = false + updated = false + } + + /** + * Update the internal data and flags in `this` to the given state that is going to be timed out. + * This method allows `this` object to be reused across many state records. + */ + def wrapTiminoutState(newState: S): Unit = { + this.state = newState + defined = true + timingOut = true + removed = false + updated = false + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala new file mode 100644 index 0000000000000..bea5b9df20b53 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import com.google.common.base.Optional +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.{JavaPairRDD, JavaUtils} +import org.apache.spark.api.java.function.{Function2 => JFunction2, Function4 => JFunction4} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.ClosureCleaner +import org.apache.spark.{HashPartitioner, Partitioner} + +/** + * :: Experimental :: + * Abstract class representing all the specifications of the DStream transformation + * `trackStateByKey` operation of a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * Use the [[org.apache.spark.streaming.StateSpec StateSpec.apply()]] or + * [[org.apache.spark.streaming.StateSpec StateSpec.create()]] to create instances of + * this class. + * + * Example in Scala: + * {{{ + * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = { + * ... + * } + * + * val spec = StateSpec.function(trackingFunction).numPartitions(10) + * + * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec) + * }}} + * + * Example in Java: + * {{{ + * StateSpec spec = + * StateSpec.function(trackingFunction) + * .numPartition(10); + * + * JavaTrackStateDStream emittedRecordDStream = + * javaPairDStream.trackStateByKey(spec); + * }}} + */ +@Experimental +sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable { + + /** Set the RDD containing the initial states that will be used by `trackStateByKey` */ + def initialState(rdd: RDD[(KeyType, StateType)]): this.type + + /** Set the RDD containing the initial states that will be used by `trackStateByKey` */ + def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type + + /** + * Set the number of partitions by which the state RDDs generated by `trackStateByKey` + * will be partitioned. Hash partitioning will be used. + */ + def numPartitions(numPartitions: Int): this.type + + /** + * Set the partitioner by which the state RDDs generated by `trackStateByKey` will be + * be partitioned. + */ + def partitioner(partitioner: Partitioner): this.type + + /** + * Set the duration after which the state of an idle key will be removed. A key and its state is + * considered idle if it has not received any data for at least the given duration. The state + * tracking function will be called one final time on the idle states that are going to be + * removed; [[org.apache.spark.streaming.State State.isTimingOut()]] set + * to `true` in that call. + */ + def timeout(idleDuration: Duration): this.type +} + + +/** + * :: Experimental :: + * Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]] + * that is used for specifying the parameters of the DStream transformation `trackStateByKey` + * that is used for specifying the parameters of the DStream transformation + * `trackStateByKey` operation of a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * + * Example in Scala: + * {{{ + * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = { + * ... + * } + * + * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType]( + * StateSpec.function(trackingFunction).numPartitions(10)) + * }}} + * + * Example in Java: + * {{{ + * StateSpec spec = + * StateSpec.function(trackingFunction) + * .numPartition(10); + * + * JavaTrackStateDStream emittedRecordDStream = + * javaPairDStream.trackStateByKey(spec); + * }}} + */ +@Experimental +object StateSpec { + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications + * of the `trackStateByKey` operation on a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. + * + * @param trackingFunction The function applied on every data item to manage the associated state + * and generate the emitted data + * @tparam KeyType Class of the keys + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam EmittedType Class of the emitted data + */ + def function[KeyType, ValueType, StateType, EmittedType]( + trackingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[EmittedType] + ): StateSpec[KeyType, ValueType, StateType, EmittedType] = { + ClosureCleaner.clean(trackingFunction, checkSerializable = true) + new StateSpecImpl(trackingFunction) + } + + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications + * of the `trackStateByKey` operation on a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. + * + * @param trackingFunction The function applied on every data item to manage the associated state + * and generate the emitted data + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam EmittedType Class of the emitted data + */ + def function[KeyType, ValueType, StateType, EmittedType]( + trackingFunction: (Option[ValueType], State[StateType]) => EmittedType + ): StateSpec[KeyType, ValueType, StateType, EmittedType] = { + ClosureCleaner.clean(trackingFunction, checkSerializable = true) + val wrappedFunction = + (time: Time, key: Any, value: Option[ValueType], state: State[StateType]) => { + Some(trackingFunction(value, state)) + } + new StateSpecImpl(wrappedFunction) + } + + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all + * the specifications of the `trackStateByKey` operation on a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]]. + * + * @param javaTrackingFunction The function applied on every data item to manage the associated + * state and generate the emitted data + * @tparam KeyType Class of the keys + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam EmittedType Class of the emitted data + */ + def function[KeyType, ValueType, StateType, EmittedType](javaTrackingFunction: + JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[EmittedType]]): + StateSpec[KeyType, ValueType, StateType, EmittedType] = { + val trackingFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => { + val t = javaTrackingFunction.call(time, k, JavaUtils.optionToOptional(v), s) + Option(t.orNull) + } + StateSpec.function(trackingFunc) + } + + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications + * of the `trackStateByKey` operation on a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]]. + * + * @param javaTrackingFunction The function applied on every data item to manage the associated + * state and generate the emitted data + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam EmittedType Class of the emitted data + */ + def function[KeyType, ValueType, StateType, EmittedType]( + javaTrackingFunction: JFunction2[Optional[ValueType], State[StateType], EmittedType]): + StateSpec[KeyType, ValueType, StateType, EmittedType] = { + val trackingFunc = (v: Option[ValueType], s: State[StateType]) => { + javaTrackingFunction.call(Optional.fromNullable(v.get), s) + } + StateSpec.function(trackingFunc) + } +} + + +/** Internal implementation of [[org.apache.spark.streaming.StateSpec]] interface. */ +private[streaming] +case class StateSpecImpl[K, V, S, T]( + function: (Time, K, Option[V], State[S]) => Option[T]) extends StateSpec[K, V, S, T] { + + require(function != null) + + @volatile private var partitioner: Partitioner = null + @volatile private var initialStateRDD: RDD[(K, S)] = null + @volatile private var timeoutInterval: Duration = null + + override def initialState(rdd: RDD[(K, S)]): this.type = { + this.initialStateRDD = rdd + this + } + + override def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = { + this.initialStateRDD = javaPairRDD.rdd + this + } + + override def numPartitions(numPartitions: Int): this.type = { + this.partitioner(new HashPartitioner(numPartitions)) + this + } + + override def partitioner(partitioner: Partitioner): this.type = { + this.partitioner = partitioner + this + } + + override def timeout(interval: Duration): this.type = { + this.timeoutInterval = interval + this + } + + // ================= Private Methods ================= + + private[streaming] def getFunction(): (Time, K, Option[V], State[S]) => Option[T] = function + + private[streaming] def getInitialStateRDD(): Option[RDD[(K, S)]] = Option(initialStateRDD) + + private[streaming] def getPartitioner(): Option[Partitioner] = Option(partitioner) + + private[streaming] def getTimeoutInterval(): Option[Duration] = Option(timeoutInterval) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index e2aec6c2f63e7..70e32b383e458 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -28,8 +28,10 @@ import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} + import org.apache.spark.Partitioner -import org.apache.spark.api.java.{JavaPairRDD, JavaUtils} +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaUtils} import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} @@ -426,6 +428,48 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( ) } + /** + * :: Experimental :: + * Return a new [[JavaDStream]] of data generated by combining the key-value data in `this` stream + * with a continuously updated per-key state. The user-provided state tracking function is + * applied on each keyed data item along with its corresponding state. The function can choose to + * update/remove the state and return a transformed data, which forms the + * [[JavaTrackStateDStream]]. + * + * The specifications of this transformation is made through the + * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there + * are a number of optional parameters - initial state data, number of partitions, timeouts, etc. + * See the [[org.apache.spark.streaming.StateSpec StateSpec]] for more details. + * + * Example of using `trackStateByKey`: + * {{{ + * // A tracking function that maintains an integer state and return a String + * Function2, State, Optional> trackStateFunc = + * new Function2, State, Optional>() { + * + * @Override + * public Optional call(Optional one, State state) { + * // Check if state exists, accordingly update/remove state and return transformed data + * } + * }; + * + * JavaTrackStateDStream trackStateDStream = + * keyValueDStream.trackStateByKey( + * StateSpec.function(trackStateFunc).numPartitions(10)); + * }}} + * + * @param spec Specification of this transformation + * @tparam StateType Class type of the state + * @tparam EmittedType Class type of the tranformed data return by the tracking function + */ + @Experimental + def trackStateByKey[StateType, EmittedType](spec: StateSpec[K, V, StateType, EmittedType]): + JavaTrackStateDStream[K, V, StateType, EmittedType] = { + new JavaTrackStateDStream(dstream.trackStateByKey(spec)( + JavaSparkContext.fakeClassTag, + JavaSparkContext.fakeClassTag)) + } + private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]): (Seq[V], Option[S]) => Option[S] = { val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala new file mode 100644 index 0000000000000..7bfd6bd5af759 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import org.apache.spark.streaming.Time +import org.apache.spark.streaming.scheduler.StreamingListener + +private[streaming] trait PythonStreamingListener{ + + /** Called when a receiver has been started */ + def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted) { } + + /** Called when a receiver has reported an error */ + def onReceiverError(receiverError: JavaStreamingListenerReceiverError) { } + + /** Called when a receiver has been stopped */ + def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped) { } + + /** Called when a batch of jobs has been submitted for processing. */ + def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted) { } + + /** Called when processing of a batch of jobs has started. */ + def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted) { } + + /** Called when processing of a batch of jobs has completed. */ + def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted) { } + + /** Called when processing of a job of a batch has started. */ + def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted) { } + + /** Called when processing of a job of a batch has completed. */ + def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted) { } +} + +private[streaming] class PythonStreamingListenerWrapper(listener: PythonStreamingListener) + extends JavaStreamingListener { + + /** Called when a receiver has been started */ + override def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { + listener.onReceiverStarted(receiverStarted) + } + + /** Called when a receiver has reported an error */ + override def onReceiverError(receiverError: JavaStreamingListenerReceiverError): Unit = { + listener.onReceiverError(receiverError) + } + + /** Called when a receiver has been stopped */ + override def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped): Unit = { + listener.onReceiverStopped(receiverStopped) + } + + /** Called when a batch of jobs has been submitted for processing. */ + override def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted): Unit = { + listener.onBatchSubmitted(batchSubmitted) + } + + /** Called when processing of a batch of jobs has started. */ + override def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted): Unit = { + listener.onBatchStarted(batchStarted) + } + + /** Called when processing of a batch of jobs has completed. */ + override def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted): Unit = { + listener.onBatchCompleted(batchCompleted) + } + + /** Called when processing of a job of a batch has started. */ + override def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted): Unit = { + listener.onOutputOperationStarted(outputOperationStarted) + } + + /** Called when processing of a job of a batch has completed. */ + override def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted): Unit = { + listener.onOutputOperationCompleted(outputOperationCompleted) + } +} + +/** + * A listener interface for receiving information about an ongoing streaming computation. + */ +private[streaming] class JavaStreamingListener { + + /** Called when a receiver has been started */ + def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { } + + /** Called when a receiver has reported an error */ + def onReceiverError(receiverError: JavaStreamingListenerReceiverError): Unit = { } + + /** Called when a receiver has been stopped */ + def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped): Unit = { } + + /** Called when a batch of jobs has been submitted for processing. */ + def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted): Unit = { } + + /** Called when processing of a batch of jobs has started. */ + def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted): Unit = { } + + /** Called when processing of a batch of jobs has completed. */ + def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted): Unit = { } + + /** Called when processing of a job of a batch has started. */ + def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted): Unit = { } + + /** Called when processing of a job of a batch has completed. */ + def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted): Unit = { } +} + +/** + * Base trait for events related to JavaStreamingListener + */ +private[streaming] sealed trait JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerBatchSubmitted(val batchInfo: JavaBatchInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerBatchCompleted(val batchInfo: JavaBatchInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerBatchStarted(val batchInfo: JavaBatchInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerOutputOperationStarted( + val outputOperationInfo: JavaOutputOperationInfo) extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerOutputOperationCompleted( + val outputOperationInfo: JavaOutputOperationInfo) extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerReceiverStarted(val receiverInfo: JavaReceiverInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerReceiverError(val receiverInfo: JavaReceiverInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerReceiverStopped(val receiverInfo: JavaReceiverInfo) + extends JavaStreamingListenerEvent + +/** + * Class having information on batches. + * + * @param batchTime Time of the batch + * @param streamIdToInputInfo A map of input stream id to its input info + * @param submissionTime Clock time of when jobs of this batch was submitted to the streaming + * scheduler queue + * @param processingStartTime Clock time of when the first job of this batch started processing. + * `-1` means the batch has not yet started + * @param processingEndTime Clock time of when the last job of this batch finished processing. `-1` + * means the batch has not yet completed. + * @param schedulingDelay Time taken for the first job of this batch to start processing from the + * time this batch was submitted to the streaming scheduler. Essentially, it + * is `processingStartTime` - `submissionTime`. `-1` means the batch has not + * yet started + * @param processingDelay Time taken for the all jobs of this batch to finish processing from the + * time they started processing. Essentially, it is + * `processingEndTime` - `processingStartTime`. `-1` means the batch has not + * yet completed. + * @param totalDelay Time taken for all the jobs of this batch to finish processing from the time + * they were submitted. Essentially, it is `processingDelay` + `schedulingDelay`. + * `-1` means the batch has not yet completed. + * @param numRecords The number of recorders received by the receivers in this batch + * @param outputOperationInfos The output operations in this batch + */ +private[streaming] case class JavaBatchInfo( + batchTime: Time, + streamIdToInputInfo: java.util.Map[Int, JavaStreamInputInfo], + submissionTime: Long, + processingStartTime: Long, + processingEndTime: Long, + schedulingDelay: Long, + processingDelay: Long, + totalDelay: Long, + numRecords: Long, + outputOperationInfos: java.util.Map[Int, JavaOutputOperationInfo]) + +/** + * Track the information of input stream at specified batch time. + * + * @param inputStreamId the input stream id + * @param numRecords the number of records in a batch + * @param metadata metadata for this batch. It should contain at least one standard field named + * "Description" which maps to the content that will be shown in the UI. + * @param metadataDescription description of this input stream + */ +private[streaming] case class JavaStreamInputInfo( + inputStreamId: Int, + numRecords: Long, + metadata: java.util.Map[String, Any], + metadataDescription: String) + +/** + * Class having information about a receiver + */ +private[streaming] case class JavaReceiverInfo( + streamId: Int, + name: String, + active: Boolean, + location: String, + executorId: String, + lastErrorMessage: String, + lastError: String, + lastErrorTime: Long) + +/** + * Class having information on output operations. + * + * @param batchTime Time of the batch + * @param id Id of this output operation. Different output operations have different ids in a batch. + * @param name The name of this output operation. + * @param description The description of this output operation. + * @param startTime Clock time of when the output operation started processing. `-1` means the + * output operation has not yet started + * @param endTime Clock time of when the output operation started processing. `-1` means the output + * operation has not yet completed + * @param failureReason Failure reason if this output operation fails. If the output operation is + * successful, this field is `null`. + */ +private[streaming] case class JavaOutputOperationInfo( + batchTime: Time, + id: Int, + name: String, + description: String, + startTime: Long, + endTime: Long, + failureReason: String) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala new file mode 100644 index 0000000000000..b109b9f1cbeae --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import scala.collection.JavaConverters._ + +import org.apache.spark.streaming.scheduler._ + +/** + * A wrapper to convert a [[JavaStreamingListener]] to a [[StreamingListener]]. + */ +private[streaming] class JavaStreamingListenerWrapper(javaStreamingListener: JavaStreamingListener) + extends StreamingListener { + + private def toJavaReceiverInfo(receiverInfo: ReceiverInfo): JavaReceiverInfo = { + JavaReceiverInfo( + receiverInfo.streamId, + receiverInfo.name, + receiverInfo.active, + receiverInfo.location, + receiverInfo.executorId, + receiverInfo.lastErrorMessage, + receiverInfo.lastError, + receiverInfo.lastErrorTime + ) + } + + private def toJavaStreamInputInfo(streamInputInfo: StreamInputInfo): JavaStreamInputInfo = { + JavaStreamInputInfo( + streamInputInfo.inputStreamId, + streamInputInfo.numRecords: Long, + streamInputInfo.metadata.asJava, + streamInputInfo.metadataDescription.orNull + ) + } + + private def toJavaOutputOperationInfo( + outputOperationInfo: OutputOperationInfo): JavaOutputOperationInfo = { + JavaOutputOperationInfo( + outputOperationInfo.batchTime, + outputOperationInfo.id, + outputOperationInfo.name, + outputOperationInfo.description: String, + outputOperationInfo.startTime.getOrElse(-1), + outputOperationInfo.endTime.getOrElse(-1), + outputOperationInfo.failureReason.orNull + ) + } + + private def toJavaBatchInfo(batchInfo: BatchInfo): JavaBatchInfo = { + JavaBatchInfo( + batchInfo.batchTime, + batchInfo.streamIdToInputInfo.mapValues(toJavaStreamInputInfo(_)).asJava, + batchInfo.submissionTime, + batchInfo.processingStartTime.getOrElse(-1), + batchInfo.processingEndTime.getOrElse(-1), + batchInfo.schedulingDelay.getOrElse(-1), + batchInfo.processingDelay.getOrElse(-1), + batchInfo.totalDelay.getOrElse(-1), + batchInfo.numRecords, + batchInfo.outputOperationInfos.mapValues(toJavaOutputOperationInfo(_)).asJava + ) + } + + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { + javaStreamingListener.onReceiverStarted( + new JavaStreamingListenerReceiverStarted(toJavaReceiverInfo(receiverStarted.receiverInfo))) + } + + override def onReceiverError(receiverError: StreamingListenerReceiverError): Unit = { + javaStreamingListener.onReceiverError( + new JavaStreamingListenerReceiverError(toJavaReceiverInfo(receiverError.receiverInfo))) + } + + override def onReceiverStopped(receiverStopped: StreamingListenerReceiverStopped): Unit = { + javaStreamingListener.onReceiverStopped( + new JavaStreamingListenerReceiverStopped(toJavaReceiverInfo(receiverStopped.receiverInfo))) + } + + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = { + javaStreamingListener.onBatchSubmitted( + new JavaStreamingListenerBatchSubmitted(toJavaBatchInfo(batchSubmitted.batchInfo))) + } + + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = { + javaStreamingListener.onBatchStarted( + new JavaStreamingListenerBatchStarted(toJavaBatchInfo(batchStarted.batchInfo))) + } + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = { + javaStreamingListener.onBatchCompleted( + new JavaStreamingListenerBatchCompleted(toJavaBatchInfo(batchCompleted.batchInfo))) + } + + override def onOutputOperationStarted( + outputOperationStarted: StreamingListenerOutputOperationStarted): Unit = { + javaStreamingListener.onOutputOperationStarted(new JavaStreamingListenerOutputOperationStarted( + toJavaOutputOperationInfo(outputOperationStarted.outputOperationInfo))) + } + + override def onOutputOperationCompleted( + outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = { + javaStreamingListener.onOutputOperationCompleted( + new JavaStreamingListenerOutputOperationCompleted( + toJavaOutputOperationInfo(outputOperationCompleted.outputOperationInfo))) + } + +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala new file mode 100644 index 0000000000000..f459930d0660b --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.streaming.dstream.TrackStateDStream + +/** + * :: Experimental :: + * [[JavaDStream]] representing the stream of records emitted by the tracking function in the + * `trackStateByKey` operation on a [[JavaPairDStream]]. Additionally, it also gives access to the + * stream of state snapshots, that is, the state data of all keys after a batch has updated them. + * + * @tparam KeyType Class of the state key + * @tparam ValueType Class of the state value + * @tparam StateType Class of the state + * @tparam EmittedType Class of the emitted records + */ +@Experimental +class JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType]( + dstream: TrackStateDStream[KeyType, ValueType, StateType, EmittedType]) + extends JavaDStream[EmittedType](dstream)(JavaSparkContext.fakeClassTag) { + + def stateSnapshots(): JavaPairDStream[KeyType, StateType] = + new JavaPairDStream(dstream.stateSnapshots())( + JavaSparkContext.fakeClassTag, + JavaSparkContext.fakeClassTag) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 1da0b0a54df07..1a6edf9473d84 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -341,7 +341,7 @@ abstract class DStream[T: ClassTag] ( // of RDD generation, else generate nothing. if (isTimeValid(time)) { - val rddOption = createRDDWithLocalProperties(time) { + val rddOption = createRDDWithLocalProperties(time, displayInnerRDDOps = false) { // Disable checks for existing output directories in jobs launched by the streaming // scheduler, since we may need to write output to an existing directory during checkpoint // recovery; see SPARK-4835 for more details. We need to have this call here because @@ -373,27 +373,52 @@ abstract class DStream[T: ClassTag] ( /** * Wrap a body of code such that the call site and operation scope * information are passed to the RDDs created in this body properly. - */ - protected def createRDDWithLocalProperties[U](time: Time)(body: => U): U = { + * @param body RDD creation code to execute with certain local properties. + * @param time Current batch time that should be embedded in the scope names + * @param displayInnerRDDOps Whether the detailed callsites and scopes of the inner RDDs generated + * by `body` will be displayed in the UI; only the scope and callsite + * of the DStream operation that generated `this` will be displayed. + */ + protected[streaming] def createRDDWithLocalProperties[U]( + time: Time, + displayInnerRDDOps: Boolean)(body: => U): U = { val scopeKey = SparkContext.RDD_SCOPE_KEY val scopeNoOverrideKey = SparkContext.RDD_SCOPE_NO_OVERRIDE_KEY // Pass this DStream's operation scope and creation site information to RDDs through // thread-local properties in our SparkContext. Since this method may be called from another // DStream, we need to temporarily store any old scope and creation site information to // restore them later after setting our own. - val prevCallSite = ssc.sparkContext.getCallSite() + val prevCallSite = CallSite( + ssc.sparkContext.getLocalProperty(CallSite.SHORT_FORM), + ssc.sparkContext.getLocalProperty(CallSite.LONG_FORM) + ) val prevScope = ssc.sparkContext.getLocalProperty(scopeKey) val prevScopeNoOverride = ssc.sparkContext.getLocalProperty(scopeNoOverrideKey) try { - ssc.sparkContext.setCallSite(creationSite) + if (displayInnerRDDOps) { + // Unset the short form call site, so that generated RDDs get their own + ssc.sparkContext.setLocalProperty(CallSite.SHORT_FORM, null) + ssc.sparkContext.setLocalProperty(CallSite.LONG_FORM, null) + } else { + // Set the callsite, so that the generated RDDs get the DStream's call site and + // the internal RDD call sites do not get displayed + ssc.sparkContext.setCallSite(creationSite) + } + // Use the DStream's base scope for this RDD so we can (1) preserve the higher level // DStream operation name, and (2) share this scope with other DStreams created in the // same operation. Disallow nesting so that low-level Spark primitives do not show up. // TODO: merge callsites with scopes so we can just reuse the code there makeScope(time).foreach { s => ssc.sparkContext.setLocalProperty(scopeKey, s.toJson) - ssc.sparkContext.setLocalProperty(scopeNoOverrideKey, "true") + if (displayInnerRDDOps) { + // Allow inner RDDs to add inner scopes + ssc.sparkContext.setLocalProperty(scopeNoOverrideKey, null) + } else { + // Do not allow inner RDDs to override the scope set by DStream + ssc.sparkContext.setLocalProperty(scopeNoOverrideKey, "true") + } } body @@ -628,7 +653,7 @@ abstract class DStream[T: ClassTag] ( */ def foreachRDD(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope { val cleanedF = context.sparkContext.clean(foreachFunc, false) - this.foreachRDD((r: RDD[T], t: Time) => cleanedF(r)) + foreachRDD((r: RDD[T], t: Time) => cleanedF(r), displayInnerRDDOps = true) } /** @@ -639,7 +664,23 @@ abstract class DStream[T: ClassTag] ( // because the DStream is reachable from the outer object here, and because // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean - new ForEachDStream(this, context.sparkContext.clean(foreachFunc, false)).register() + foreachRDD(foreachFunc, displayInnerRDDOps = true) + } + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + * @param foreachFunc foreachRDD function + * @param displayInnerRDDOps Whether the detailed callsites and scopes of the RDDs generated + * in the `foreachFunc` to be displayed in the UI. If `false`, then + * only the scopes and callsites of `foreachRDD` will override those + * of the RDDs on the display. + */ + private def foreachRDD( + foreachFunc: (RDD[T], Time) => Unit, + displayInnerRDDOps: Boolean): Unit = { + new ForEachDStream(this, + context.sparkContext.clean(foreachFunc, false), displayInnerRDDOps).register() } /** @@ -730,7 +771,7 @@ abstract class DStream[T: ClassTag] ( // scalastyle:on println } } - new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register() + foreachRDD(context.sparkContext.clean(foreachFunc), displayInnerRDDOps = false) } /** @@ -900,7 +941,7 @@ abstract class DStream[T: ClassTag] ( val file = rddToFileName(prefix, suffix, time) rdd.saveAsObjectFile(file) } - this.foreachRDD(saveFunc) + this.foreachRDD(saveFunc, displayInnerRDDOps = false) } /** @@ -913,7 +954,7 @@ abstract class DStream[T: ClassTag] ( val file = rddToFileName(prefix, suffix, time) rdd.saveAsTextFile(file) } - this.foreachRDD(saveFunc) + this.foreachRDD(saveFunc, displayInnerRDDOps = false) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index c109ceccc6989..4410a9977c87b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -22,10 +22,19 @@ import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.scheduler.Job import scala.reflect.ClassTag +/** + * An internal DStream used to represent output operations like DStream.foreachRDD. + * @param parent Parent DStream + * @param foreachFunc Function to apply on each RDD generated by the parent DStream + * @param displayInnerRDDOps Whether the detailed callsites and scopes of the RDDs generated + * by `foreachFunc` will be displayed in the UI; only the scope and + * callsite of `DStream.foreachRDD` will be displayed. + */ private[streaming] class ForEachDStream[T: ClassTag] ( parent: DStream[T], - foreachFunc: (RDD[T], Time) => Unit + foreachFunc: (RDD[T], Time) => Unit, + displayInnerRDDOps: Boolean ) extends DStream[Unit](parent.ssc) { override def dependencies: List[DStream[_]] = List(parent) @@ -37,8 +46,7 @@ class ForEachDStream[T: ClassTag] ( override def generateJob(time: Time): Option[Job] = { parent.getOrCompute(time) match { case Some(rdd) => - val jobFunc = () => createRDDWithLocalProperties(time) { - ssc.sparkContext.setCallSite(creationSite) + val jobFunc = () => createRDDWithLocalProperties(time, displayInnerRDDOps) { foreachFunc(rdd, time) } Some(new Job(time, jobFunc)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 71bec96d46c8d..fb691eed27e32 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -24,19 +24,19 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.spark.{HashPartitioner, Partitioner} +import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.StreamingContext.rddToFileName +import org.apache.spark.streaming._ import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf} +import org.apache.spark.{HashPartitioner, Partitioner} /** * Extra functions available on DStream of (key, value) pairs through an implicit conversion. */ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K]) - extends Serializable -{ + extends Serializable { private[streaming] def ssc = self.ssc private[streaming] def sparkContext = self.context.sparkContext @@ -350,6 +350,44 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) ) } + /** + * :: Experimental :: + * Return a new DStream of data generated by combining the key-value data in `this` stream + * with a continuously updated per-key state. The user-provided state tracking function is + * applied on each keyed data item along with its corresponding state. The function can choose to + * update/remove the state and return a transformed data, which forms the + * [[org.apache.spark.streaming.dstream.TrackStateDStream]]. + * + * The specifications of this transformation is made through the + * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there + * are a number of optional parameters - initial state data, number of partitions, timeouts, etc. + * See the [[org.apache.spark.streaming.StateSpec StateSpec spec docs]] for more details. + * + * Example of using `trackStateByKey`: + * {{{ + * def trackingFunction(data: Option[Int], wrappedState: State[Int]): String = { + * // Check if state exists, accordingly update/remove state and return transformed data + * } + * + * val spec = StateSpec.function(trackingFunction).numPartitions(10) + * + * val trackStateDStream = keyValueDStream.trackStateByKey[Int, String](spec) + * }}} + * + * @param spec Specification of this transformation + * @tparam StateType Class type of the state + * @tparam EmittedType Class type of the tranformed data return by the tracking function + */ + @Experimental + def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag]( + spec: StateSpec[K, V, StateType, EmittedType] + ): TrackStateDStream[K, V, StateType, EmittedType] = { + new TrackStateDStreamImpl[K, V, StateType, EmittedType]( + self, + spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]] + ) + } + /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala new file mode 100644 index 0000000000000..98e881e6ae115 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.{EmptyRDD, RDD} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ +import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord} + +/** + * :: Experimental :: + * DStream representing the stream of records emitted by the tracking function in the + * `trackStateByKey` operation on a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. + * Additionally, it also gives access to the stream of state snapshots, that is, the state data of + * all keys after a batch has updated them. + * + * @tparam KeyType Class of the state key + * @tparam ValueType Class of the state value + * @tparam StateType Class of the state data + * @tparam EmittedType Class of the emitted records + */ +@Experimental +sealed abstract class TrackStateDStream[KeyType, ValueType, StateType, EmittedType: ClassTag]( + ssc: StreamingContext) extends DStream[EmittedType](ssc) { + + /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */ + def stateSnapshots(): DStream[(KeyType, StateType)] +} + +/** Internal implementation of the [[TrackStateDStream]] */ +private[streaming] class TrackStateDStreamImpl[ + KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, EmittedType: ClassTag]( + dataStream: DStream[(KeyType, ValueType)], + spec: StateSpecImpl[KeyType, ValueType, StateType, EmittedType]) + extends TrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream.context) { + + private val internalStream = + new InternalTrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream, spec) + + override def slideDuration: Duration = internalStream.slideDuration + + override def dependencies: List[DStream[_]] = List(internalStream) + + override def compute(validTime: Time): Option[RDD[EmittedType]] = { + internalStream.getOrCompute(validTime).map { _.flatMap[EmittedType] { _.emittedRecords } } + } + + /** + * Forward the checkpoint interval to the internal DStream that computes the state maps. This + * to make sure that this DStream does not get checkpointed, only the internal stream. + */ + override def checkpoint(checkpointInterval: Duration): DStream[EmittedType] = { + internalStream.checkpoint(checkpointInterval) + this + } + + /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */ + def stateSnapshots(): DStream[(KeyType, StateType)] = { + internalStream.flatMap { + _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable } + } + + def keyClass: Class[_] = implicitly[ClassTag[KeyType]].runtimeClass + + def valueClass: Class[_] = implicitly[ClassTag[ValueType]].runtimeClass + + def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass + + def emittedClass: Class[_] = implicitly[ClassTag[EmittedType]].runtimeClass +} + +/** + * A DStream that allows per-key state to be maintains, and arbitrary records to be generated + * based on updates to the state. This is the main DStream that implements the `trackStateByKey` + * operation on DStreams. + * + * @param parent Parent (key, value) stream that is the source + * @param spec Specifications of the trackStateByKey operation + * @tparam K Key type + * @tparam V Value type + * @tparam S Type of the state maintained + * @tparam E Type of the emitted data + */ +private[streaming] +class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E]) + extends DStream[TrackStateRDDRecord[K, S, E]](parent.context) { + + persist(StorageLevel.MEMORY_ONLY) + + private val partitioner = spec.getPartitioner().getOrElse( + new HashPartitioner(ssc.sc.defaultParallelism)) + + private val trackingFunction = spec.getFunction() + + override def slideDuration: Duration = parent.slideDuration + + override def dependencies: List[DStream[_]] = List(parent) + + /** Enable automatic checkpointing */ + override val mustCheckpoint = true + + /** Method that generates a RDD for the given time */ + override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = { + // Get the previous state or create a new empty state RDD + val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse { + TrackStateRDD.createFromPairRDD[K, V, S, E]( + spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), + partitioner, validTime + ) + } + + // Compute the new state RDD with previous state RDD and partitioned data RDD + parent.getOrCompute(validTime).map { dataRDD => + val partitionedDataRDD = dataRDD.partitionBy(partitioner) + val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => + (validTime - interval).milliseconds + } + new TrackStateRDD( + prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime) + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala index 5eabdf63dc8d7..080bc873fa0a8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -51,4 +51,17 @@ class TransformedDStream[U: ClassTag] ( } Some(transformedRDD) } + + /** + * Wrap a body of code such that the call site and operation scope + * information are passed to the RDDs created in this body properly. + * This has been overriden to make sure that `displayInnerRDDOps` is always `true`, that is, + * the inner scopes and callsites of RDDs generated in `DStream.transform` are always + * displayed in the UI. + */ + override protected[streaming] def createRDDWithLocalProperties[U]( + time: Time, + displayInnerRDDOps: Boolean)(body: => U): U = { + super.createRDDWithLocalProperties(time, displayInnerRDDOps = true)(body) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala new file mode 100644 index 0000000000000..7050378d0feb0 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.rdd + +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.spark.rdd.{MapPartitionsRDD, RDD} +import org.apache.spark.streaming.{Time, StateImpl, State} +import org.apache.spark.streaming.util.{EmptyStateMap, StateMap} +import org.apache.spark.util.Utils +import org.apache.spark._ + +/** + * Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a + * sequence of records returned by the tracking function of `trackStateByKey`. + */ +private[streaming] case class TrackStateRDDRecord[K, S, E]( + var stateMap: StateMap[K, S], var emittedRecords: Seq[E]) + +private[streaming] object TrackStateRDDRecord { + def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + prevRecord: Option[TrackStateRDDRecord[K, S, E]], + dataIterator: Iterator[(K, V)], + updateFunction: (Time, K, Option[V], State[S]) => Option[E], + batchTime: Time, + timeoutThresholdTime: Option[Long], + removeTimedoutData: Boolean + ): TrackStateRDDRecord[K, S, E] = { + // Create a new state map by cloning the previous one (if it exists) or by creating an empty one + val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() } + + val emittedRecords = new ArrayBuffer[E] + val wrappedState = new StateImpl[S]() + + // Call the tracking function on each record in the data iterator, and accordingly + // update the states touched, and collect the data returned by the tracking function + dataIterator.foreach { case (key, value) => + wrappedState.wrap(newStateMap.get(key)) + val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState) + if (wrappedState.isRemoved) { + newStateMap.remove(key) + } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) { + newStateMap.put(key, wrappedState.get(), batchTime.milliseconds) + } + emittedRecords ++= emittedRecord + } + + // Get the timed out state records, call the tracking function on each and collect the + // data returned + if (removeTimedoutData && timeoutThresholdTime.isDefined) { + newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => + wrappedState.wrapTiminoutState(state) + val emittedRecord = updateFunction(batchTime, key, None, wrappedState) + emittedRecords ++= emittedRecord + newStateMap.remove(key) + } + } + + TrackStateRDDRecord(newStateMap, emittedRecords) + } +} + +/** + * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state + * RDD, and a partitioned keyed-data RDD + */ +private[streaming] class TrackStateRDDPartition( + idx: Int, + @transient private var prevStateRDD: RDD[_], + @transient private var partitionedDataRDD: RDD[_]) extends Partition { + + private[rdd] var previousSessionRDDPartition: Partition = null + private[rdd] var partitionedDataRDDPartition: Partition = null + + override def index: Int = idx + override def hashCode(): Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { + // Update the reference to parent split at the time of task serialization + previousSessionRDDPartition = prevStateRDD.partitions(index) + partitionedDataRDDPartition = partitionedDataRDD.partitions(index) + oos.defaultWriteObject() + } +} + + +/** + * RDD storing the keyed-state of `trackStateByKey` and corresponding emitted records. + * Each partition of this RDD has a single record of type [[TrackStateRDDRecord]]. This contains a + * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the tracking + * function of `trackStateByKey`. + * @param prevStateRDD The previous TrackStateRDD on whose StateMap data `this` RDD will be created + * @param partitionedDataRDD The partitioned data RDD which is used update the previous StateMaps + * in the `prevStateRDD` to create `this` RDD + * @param trackingFunction The function that will be used to update state and return new data + * @param batchTime The time of the batch to which this RDD belongs to. Use to update + * @param timeoutThresholdTime The time to indicate which keys are timeout + */ +private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, E]], + private var partitionedDataRDD: RDD[(K, V)], + trackingFunction: (Time, K, Option[V], State[S]) => Option[E], + batchTime: Time, + timeoutThresholdTime: Option[Long] + ) extends RDD[TrackStateRDDRecord[K, S, E]]( + partitionedDataRDD.sparkContext, + List( + new OneToOneDependency[TrackStateRDDRecord[K, S, E]](prevStateRDD), + new OneToOneDependency(partitionedDataRDD)) + ) { + + @volatile private var doFullScan = false + + require(prevStateRDD.partitioner.nonEmpty) + require(partitionedDataRDD.partitioner == prevStateRDD.partitioner) + + override val partitioner = prevStateRDD.partitioner + + override def checkpoint(): Unit = { + super.checkpoint() + doFullScan = true + } + + override def compute( + partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, E]] = { + + val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition] + val prevStateRDDIterator = prevStateRDD.iterator( + stateRDDPartition.previousSessionRDDPartition, context) + val dataIterator = partitionedDataRDD.iterator( + stateRDDPartition.partitionedDataRDDPartition, context) + + val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None + val newRecord = TrackStateRDDRecord.updateRecordWithData( + prevRecord, + dataIterator, + trackingFunction, + batchTime, + timeoutThresholdTime, + removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled + ) + Iterator(newRecord) + } + + override protected def getPartitions: Array[Partition] = { + Array.tabulate(prevStateRDD.partitions.length) { i => + new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)} + } + + override def clearDependencies(): Unit = { + super.clearDependencies() + prevStateRDD = null + partitionedDataRDD = null + } + + def setFullScan(): Unit = { + doFullScan = true + } +} + +private[streaming] object TrackStateRDD { + + def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + pairRDD: RDD[(K, S)], + partitioner: Partitioner, + updateTime: Time): TrackStateRDD[K, V, S, T] = { + + val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator => + val stateMap = StateMap.create[K, S](SparkEnv.get.conf) + iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) } + Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T])) + }, preservesPartitioning = true) + + val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) + + val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None + + new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) + } +} + +private[streaming] class EmittedRecordsRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + parent: TrackStateRDD[K, V, S, T]) extends RDD[T](parent) { + override protected def getPartitions: Array[Partition] = parent.partitions + override def compute(partition: Partition, context: TaskContext): Iterator[T] = { + parent.compute(partition, context).flatMap { _.emittedRecords } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 2480b4ec093e2..1ed6fb0aa9d52 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -88,8 +88,10 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { if (eventLoop == null) return // scheduler has already been stopped logDebug("Stopping JobScheduler") - // First, stop receiving - receiverTracker.stop(processAllReceivedData) + if (receiverTracker != null) { + // First, stop receiving + receiverTracker.stop(processAllReceivedData) + } // Second, stop generating jobs. If it has to process all received data, // then this will wait for all the processing through JobScheduler to be over. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index f2711d1355e60..500dc70c98506 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -22,12 +22,13 @@ import java.nio.ByteBuffer import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.implicitConversions +import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.streaming.Time -import org.apache.spark.streaming.util.{WriteAheadLog, WriteAheadLogUtils} +import org.apache.spark.streaming.util.{BatchedWriteAheadLog, WriteAheadLog, WriteAheadLogUtils} import org.apache.spark.util.{Clock, Utils} import org.apache.spark.{Logging, SparkConf} @@ -41,7 +42,6 @@ private[streaming] case class BatchAllocationEvent(time: Time, allocatedBlocks: private[streaming] case class BatchCleanupEvent(times: Seq[Time]) extends ReceivedBlockTrackerLogEvent - /** Class representing the blocks of all the streams allocated to a batch */ private[streaming] case class AllocatedBlocks(streamIdToAllocatedBlocks: Map[Int, Seq[ReceivedBlockInfo]]) { @@ -82,15 +82,22 @@ private[streaming] class ReceivedBlockTracker( } /** Add received block. This event will get written to the write ahead log (if enabled). */ - def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = synchronized { + def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = { try { - writeToLog(BlockAdditionEvent(receivedBlockInfo)) - getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo - logDebug(s"Stream ${receivedBlockInfo.streamId} received " + - s"block ${receivedBlockInfo.blockStoreResult.blockId}") - true + val writeResult = writeToLog(BlockAdditionEvent(receivedBlockInfo)) + if (writeResult) { + synchronized { + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + } + logDebug(s"Stream ${receivedBlockInfo.streamId} received " + + s"block ${receivedBlockInfo.blockStoreResult.blockId}") + } else { + logDebug(s"Failed to acknowledge stream ${receivedBlockInfo.streamId} receiving " + + s"block ${receivedBlockInfo.blockStoreResult.blockId} in the Write Ahead Log.") + } + writeResult } catch { - case e: Exception => + case NonFatal(e) => logError(s"Error adding block $receivedBlockInfo", e) false } @@ -106,10 +113,12 @@ private[streaming] class ReceivedBlockTracker( (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true)) }.toMap val allocatedBlocks = AllocatedBlocks(streamIdToBlocks) - writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks)) - timeToAllocatedBlocks(batchTime) = allocatedBlocks - lastAllocatedBatchTime = batchTime - allocatedBlocks + if (writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))) { + timeToAllocatedBlocks.put(batchTime, allocatedBlocks) + lastAllocatedBatchTime = batchTime + } else { + logInfo(s"Possibly processed batch $batchTime need to be processed again in WAL recovery") + } } else { // This situation occurs when: // 1. WAL is ended with BatchAllocationEvent, but without BatchCleanupEvent, @@ -157,9 +166,12 @@ private[streaming] class ReceivedBlockTracker( require(cleanupThreshTime.milliseconds < clock.getTimeMillis()) val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq logInfo("Deleting batches " + timesToCleanup) - writeToLog(BatchCleanupEvent(timesToCleanup)) - timeToAllocatedBlocks --= timesToCleanup - writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion)) + if (writeToLog(BatchCleanupEvent(timesToCleanup))) { + timeToAllocatedBlocks --= timesToCleanup + writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion)) + } else { + logWarning("Failed to acknowledge batch clean up in the Write Ahead Log.") + } } /** Stop the block tracker. */ @@ -185,8 +197,8 @@ private[streaming] class ReceivedBlockTracker( logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " + s"${allocatedBlocks.streamIdToAllocatedBlocks}") streamIdToUnallocatedBlockQueues.values.foreach { _.clear() } - lastAllocatedBatchTime = batchTime timeToAllocatedBlocks.put(batchTime, allocatedBlocks) + lastAllocatedBatchTime = batchTime } // Cleanup the batch allocations @@ -213,12 +225,20 @@ private[streaming] class ReceivedBlockTracker( } /** Write an update to the tracker to the write ahead log */ - private def writeToLog(record: ReceivedBlockTrackerLogEvent) { + private def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = { if (isWriteAheadLogEnabled) { - logDebug(s"Writing to log $record") - writeAheadLogOption.foreach { logManager => - logManager.write(ByteBuffer.wrap(Utils.serialize(record)), clock.getTimeMillis()) + logTrace(s"Writing record: $record") + try { + writeAheadLogOption.get.write(ByteBuffer.wrap(Utils.serialize(record)), + clock.getTimeMillis()) + true + } catch { + case NonFatal(e) => + logWarning(s"Exception thrown while writing record: $record to the WriteAheadLog.", e) + false } + } else { + true } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala index 59df892397fe0..3b35964114c02 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala @@ -30,6 +30,7 @@ case class ReceiverInfo( name: String, active: Boolean, location: String, + executorId: String, lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index b183d856f50c3..ea5d12b50fcc5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.scheduler import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.mutable.HashMap -import scala.concurrent.ExecutionContext +import scala.concurrent.{Future, ExecutionContext} import scala.language.existentials import scala.util.{Failure, Success} @@ -437,7 +437,12 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // TODO Remove this thread pool after https://github.com/apache/spark/issues/7385 is merged private val submitJobThreadPool = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("submit-job-thead-pool")) + ThreadUtils.newDaemonCachedThreadPool("submit-job-thread-pool")) + + private val walBatchingThreadPool = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("wal-batching-thread-pool")) + + @volatile private var active: Boolean = true override def receive: PartialFunction[Any, Unit] = { // Local messages @@ -488,7 +493,19 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false registerReceiver(streamId, typ, host, executorId, receiverEndpoint, context.senderAddress) context.reply(successful) case AddBlock(receivedBlockInfo) => - context.reply(addBlock(receivedBlockInfo)) + if (WriteAheadLogUtils.isBatchingEnabled(ssc.conf, isDriver = true)) { + walBatchingThreadPool.execute(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + if (active) { + context.reply(addBlock(receivedBlockInfo)) + } else { + throw new IllegalStateException("ReceiverTracker RpcEndpoint shut down.") + } + } + }) + } else { + context.reply(addBlock(receivedBlockInfo)) + } case DeregisterReceiver(streamId, message, error) => deregisterReceiver(streamId, message, error) context.reply(true) @@ -599,6 +616,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false override def onStop(): Unit = { submitJobThreadPool.shutdownNow() + active = false + walBatchingThreadPool.shutdown() } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala index ab0a84f05214d..4dc5bb9c3bfbe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala @@ -49,6 +49,7 @@ private[streaming] case class ReceiverTrackingInfo( name.getOrElse(""), state == ReceiverState.ACTIVE, location = runningExecutor.map(_.host).getOrElse(""), + executorId = runningExecutor.map(_.executorId).getOrElse(""), lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""), lastError = errorInfo.map(_.lastError).getOrElse(""), lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 96d943e75d272..4588b2163cd44 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -402,7 +402,7 @@ private[ui] class StreamingPage(parent: StreamingTab)
    Status
    -
    Location
    +
    Executor ID / Host
    Last Error Time
    Last Error Message @@ -430,7 +430,11 @@ private[ui] class StreamingPage(parent: StreamingTab) val receiverActive = receiverInfo.map { info => if (info.active) "ACTIVE" else "INACTIVE" }.getOrElse(emptyCell) - val receiverLocation = receiverInfo.map(_.location).getOrElse(emptyCell) + val receiverLocation = receiverInfo.map { info => + val executorId = if (info.executorId.isEmpty) emptyCell else info.executorId + val location = if (info.location.isEmpty) emptyCell else info.location + s"$executorId / $location" + }.getOrElse(emptyCell) val receiverLastError = receiverInfo.map { info => val msg = s"${info.lastErrorMessage} - ${info.lastError}" if (msg.size > 100) msg.take(97) + "..." else msg diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala new file mode 100644 index 0000000000000..6e6ed8d819721 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import java.nio.ByteBuffer +import java.util.concurrent.LinkedBlockingQueue +import java.util.{Iterator => JIterator} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.{Await, Promise} +import scala.concurrent.duration._ +import scala.util.control.NonFatal + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.Utils + +/** + * A wrapper for a WriteAheadLog that batches records before writing data. Handles aggregation + * during writes, and de-aggregation in the `readAll` method. The end consumer has to handle + * de-aggregation after the `read` method. In addition, the `WriteAheadLogRecordHandle` returned + * after the write will contain the batch of records rather than individual records. + * + * When writing a batch of records, the `time` passed to the `wrappedLog` will be the timestamp + * of the latest record in the batch. This is very important in achieving correctness. Consider the + * following example: + * We receive records with timestamps 1, 3, 5, 7. We use "log-1" as the filename. Once we receive + * a clean up request for timestamp 3, we would clean up the file "log-1", and lose data regarding + * 5 and 7. + * + * This means the caller can assume the same write semantics as any other WriteAheadLog + * implementation despite the batching in the background - when the write() returns, the data is + * written to the WAL and is durable. To take advantage of the batching, the caller can write from + * multiple threads, each of which will stay blocked until the corresponding data has been written. + * + * All other methods of the WriteAheadLog interface will be passed on to the wrapped WriteAheadLog. + */ +private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: SparkConf) + extends WriteAheadLog with Logging { + + import BatchedWriteAheadLog._ + + private val walWriteQueue = new LinkedBlockingQueue[Record]() + + // Whether the writer thread is active + @volatile private var active: Boolean = true + private val buffer = new ArrayBuffer[Record]() + + private val batchedWriterThread = startBatchedWriterThread() + + /** + * Write a byte buffer to the log file. This method adds the byteBuffer to a queue and blocks + * until the record is properly written by the parent. + */ + override def write(byteBuffer: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { + val promise = Promise[WriteAheadLogRecordHandle]() + val putSuccessfully = synchronized { + if (active) { + walWriteQueue.offer(Record(byteBuffer, time, promise)) + true + } else { + false + } + } + if (putSuccessfully) { + Await.result(promise.future, WriteAheadLogUtils.getBatchingTimeout(conf).milliseconds) + } else { + throw new IllegalStateException("close() was called on BatchedWriteAheadLog before " + + s"write request with time $time could be fulfilled.") + } + } + + /** + * This method is not supported as the resulting ByteBuffer would actually require de-aggregation. + * This method is primarily used in testing, and to ensure that it is not used in production, + * we throw an UnsupportedOperationException. + */ + override def read(segment: WriteAheadLogRecordHandle): ByteBuffer = { + throw new UnsupportedOperationException("read() is not supported for BatchedWriteAheadLog " + + "as the data may require de-aggregation.") + } + + /** + * Read all the existing logs from the log directory. The output of the wrapped WriteAheadLog + * will be de-aggregated. + */ + override def readAll(): JIterator[ByteBuffer] = { + wrappedLog.readAll().asScala.flatMap(deaggregate).asJava + } + + /** + * Delete the log files that are older than the threshold time. + * + * This method is handled by the parent WriteAheadLog. + */ + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { + wrappedLog.clean(threshTime, waitForCompletion) + } + + + /** + * Stop the batched writer thread, fulfill promises with failures and close the wrapped WAL. + */ + override def close(): Unit = { + logInfo(s"BatchedWriteAheadLog shutting down at time: ${System.currentTimeMillis()}.") + synchronized { + active = false + } + batchedWriterThread.interrupt() + batchedWriterThread.join() + while (!walWriteQueue.isEmpty) { + val Record(_, time, promise) = walWriteQueue.poll() + promise.failure(new IllegalStateException("close() was called on BatchedWriteAheadLog " + + s"before write request with time $time could be fulfilled.")) + } + wrappedLog.close() + } + + /** Start the actual log writer on a separate thread. */ + private def startBatchedWriterThread(): Thread = { + val thread = new Thread(new Runnable { + override def run(): Unit = { + while (active) { + try { + flushRecords() + } catch { + case NonFatal(e) => + logWarning("Encountered exception in Batched Writer Thread.", e) + } + } + logInfo("BatchedWriteAheadLog Writer thread exiting.") + } + }, "BatchedWriteAheadLog Writer") + thread.setDaemon(true) + thread.start() + thread + } + + /** Write all the records in the buffer to the write ahead log. */ + private def flushRecords(): Unit = { + try { + buffer.append(walWriteQueue.take()) + val numBatched = walWriteQueue.drainTo(buffer.asJava) + 1 + logDebug(s"Received $numBatched records from queue") + } catch { + case _: InterruptedException => + logWarning("BatchedWriteAheadLog Writer queue interrupted.") + } + try { + var segment: WriteAheadLogRecordHandle = null + if (buffer.length > 0) { + logDebug(s"Batched ${buffer.length} records for Write Ahead Log write") + // We take the latest record for the timestamp. Please refer to the class Javadoc for + // detailed explanation + val time = buffer.last.time + segment = wrappedLog.write(aggregate(buffer), time) + } + buffer.foreach(_.promise.success(segment)) + } catch { + case e: InterruptedException => + logWarning("BatchedWriteAheadLog Writer queue interrupted.", e) + buffer.foreach(_.promise.failure(e)) + case NonFatal(e) => + logWarning(s"BatchedWriteAheadLog Writer failed to write $buffer", e) + buffer.foreach(_.promise.failure(e)) + } finally { + buffer.clear() + } + } + + /** Method for querying the queue length. Should only be used in tests. */ + private def getQueueLength(): Int = walWriteQueue.size() +} + +/** Static methods for aggregating and de-aggregating records. */ +private[util] object BatchedWriteAheadLog { + + /** + * Wrapper class for representing the records that we will write to the WriteAheadLog. Coupled + * with the timestamp for the write request of the record, and the promise that will block the + * write request, while a separate thread is actually performing the write. + */ + case class Record(data: ByteBuffer, time: Long, promise: Promise[WriteAheadLogRecordHandle]) + + /** Copies the byte array of a ByteBuffer. */ + private def getByteArray(buffer: ByteBuffer): Array[Byte] = { + val byteArray = new Array[Byte](buffer.remaining()) + buffer.get(byteArray) + byteArray + } + + /** Aggregate multiple serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. */ + def aggregate(records: Seq[Record]): ByteBuffer = { + ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]]( + records.map(record => getByteArray(record.data)).toArray)) + } + + /** + * De-aggregate serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. + * A stream may not have used batching initially, but started using it after a restart. This + * method therefore needs to be backwards compatible. + */ + def deaggregate(buffer: ByteBuffer): Array[ByteBuffer] = { + try { + Utils.deserialize[Array[Array[Byte]]](getByteArray(buffer)).map(ByteBuffer.wrap) + } catch { + case _: ClassCastException => // users may restart a stream with batching enabled + Array(buffer) + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index bc3f2486c21fd..72705f1a9c010 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -17,10 +17,12 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer +import java.util.concurrent.ThreadPoolExecutor import java.util.{Iterator => JIterator} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import scala.collection.parallel.ThreadPoolTaskSupport import scala.concurrent.{Await, ExecutionContext, Future} import scala.language.postfixOps @@ -57,8 +59,8 @@ private[streaming] class FileBasedWriteAheadLog( private val callerNameTag = getCallerName.map(c => s" for $c").getOrElse("") private val threadpoolName = s"WriteAheadLogManager $callerNameTag" - implicit private val executionContext = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonSingleThreadExecutor(threadpoolName)) + private val threadpool = ThreadUtils.newDaemonCachedThreadPool(threadpoolName, 20) + private val executionContext = ExecutionContext.fromExecutorService(threadpool) override protected val logName = s"WriteAheadLogManager $callerNameTag" private var currentLogPath: Option[String] = None @@ -124,13 +126,19 @@ private[streaming] class FileBasedWriteAheadLog( */ def readAll(): JIterator[ByteBuffer] = synchronized { val logFilesToRead = pastLogs.map{ _.path} ++ currentLogPath - logInfo("Reading from the logs: " + logFilesToRead.mkString("\n")) - - logFilesToRead.iterator.map { file => + logInfo("Reading from the logs:\n" + logFilesToRead.mkString("\n")) + def readFile(file: String): Iterator[ByteBuffer] = { logDebug(s"Creating log reader with $file") val reader = new FileBasedWriteAheadLogReader(file, hadoopConf) CompletionIterator[ByteBuffer, Iterator[ByteBuffer]](reader, reader.close _) - }.flatten.asJava + } + if (!closeFileAfterWrite) { + logFilesToRead.iterator.map(readFile).flatten.asJava + } else { + // For performance gains, it makes sense to parallelize the recovery if + // closeFileAfterWrite = true + seqToParIterator(threadpool, logFilesToRead, readFile).asJava + } } /** @@ -146,30 +154,33 @@ private[streaming] class FileBasedWriteAheadLog( * asynchronously. */ def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { - val oldLogFiles = synchronized { pastLogs.filter { _.endTime < threshTime } } + val oldLogFiles = synchronized { + val expiredLogs = pastLogs.filter { _.endTime < threshTime } + pastLogs --= expiredLogs + expiredLogs + } logInfo(s"Attempting to clear ${oldLogFiles.size} old log files in $logDirectory " + s"older than $threshTime: ${oldLogFiles.map { _.path }.mkString("\n")}") - def deleteFiles() { - oldLogFiles.foreach { logInfo => - try { - val path = new Path(logInfo.path) - val fs = HdfsUtils.getFileSystemForPath(path, hadoopConf) - fs.delete(path, true) - synchronized { pastLogs -= logInfo } - logDebug(s"Cleared log file $logInfo") - } catch { - case ex: Exception => - logWarning(s"Error clearing write ahead log file $logInfo", ex) - } + def deleteFile(walInfo: LogInfo): Unit = { + try { + val path = new Path(walInfo.path) + val fs = HdfsUtils.getFileSystemForPath(path, hadoopConf) + fs.delete(path, true) + logDebug(s"Cleared log file $walInfo") + } catch { + case ex: Exception => + logWarning(s"Error clearing write ahead log file $walInfo", ex) } logInfo(s"Cleared log files in $logDirectory older than $threshTime") } - if (!executionContext.isShutdown) { - val f = Future { deleteFiles() } - if (waitForCompletion) { - import scala.concurrent.duration._ - Await.ready(f, 1 second) + oldLogFiles.foreach { logInfo => + if (!executionContext.isShutdown) { + val f = Future { deleteFile(logInfo) }(executionContext) + if (waitForCompletion) { + import scala.concurrent.duration._ + Await.ready(f, 1 second) + } } } } @@ -251,4 +262,23 @@ private[streaming] object FileBasedWriteAheadLog { } }.sortBy { _.startTime } } + + /** + * This creates an iterator from a parallel collection, by keeping at most `n` objects in memory + * at any given time, where `n` is the size of the thread pool. This is crucial for use cases + * where we create `FileBasedWriteAheadLogReader`s during parallel recovery. We don't want to + * open up `k` streams altogether where `k` is the size of the Seq that we want to parallelize. + */ + def seqToParIterator[I, O]( + tpool: ThreadPoolExecutor, + source: Seq[I], + handler: I => Iterator[O]): Iterator[O] = { + val taskSupport = new ThreadPoolTaskSupport(tpool) + val groupSize = tpool.getMaximumPoolSize.max(8) + source.grouped(groupSize).flatMap { group => + val parallelCollection = group.par + parallelCollection.tasksupport = taskSupport + parallelCollection.map(handler) + }.flatten + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala index f7168229ec15a..56d4977da0b51 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala @@ -30,7 +30,7 @@ private[streaming] class FileBasedWriteAheadLogRandomReader(path: String, conf: extends Closeable { private val instream = HdfsUtils.getInputStream(path, conf) - private var closed = false + private var closed = (instream == null) // the file may be deleted as we're opening the stream def read(segment: FileBasedWriteAheadLogSegment): ByteBuffer = synchronized { assertOpen() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala index c3bb59f3fef94..a375c0729534b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.streaming.util -import java.io.{Closeable, EOFException} +import java.io.{IOException, Closeable, EOFException} import java.nio.ByteBuffer import org.apache.hadoop.conf.Configuration @@ -32,7 +32,7 @@ private[streaming] class FileBasedWriteAheadLogReader(path: String, conf: Config extends Iterator[ByteBuffer] with Closeable with Logging { private val instream = HdfsUtils.getInputStream(path, conf) - private var closed = false + private var closed = (instream == null) // the file may be deleted as we're opening the stream private var nextItem: Option[ByteBuffer] = None override def hasNext: Boolean = synchronized { @@ -55,6 +55,19 @@ private[streaming] class FileBasedWriteAheadLogReader(path: String, conf: Config logDebug("Error reading next item, EOF reached", e) close() false + case e: IOException => + logWarning("Error while trying to read data. If the file was deleted, " + + "this should be okay.", e) + close() + if (HdfsUtils.checkFileExists(path, conf)) { + // If file exists, this could be a legitimate error + throw e + } else { + // File was deleted. This can occur when the daemon cleanup thread takes time to + // delete the file during recovery. + false + } + case e: Exception => logWarning("Error while trying to read data from HDFS.", e) close() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala index f60688f173c44..13a765d035ee8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.streaming.util +import java.io.IOException + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ @@ -42,8 +44,19 @@ private[streaming] object HdfsUtils { def getInputStream(path: String, conf: Configuration): FSDataInputStream = { val dfsPath = new Path(path) val dfs = getFileSystemForPath(dfsPath, conf) - val instream = dfs.open(dfsPath) - instream + if (dfs.isFile(dfsPath)) { + try { + dfs.open(dfsPath) + } catch { + case e: IOException => + // If we are really unlucky, the file may be deleted as we're opening the stream. + // This can happen as clean up is performed by daemon threads that may be left over from + // previous runs. + if (!dfs.isFile(dfsPath)) null else throw e + } + } else { + null + } } def checkState(state: Boolean, errorMsg: => String) { @@ -71,4 +84,11 @@ private[streaming] object HdfsUtils { case _ => fs } } + + /** Check if the file exists at the given path. */ + def checkFileExists(path: String, conf: Configuration): Boolean = { + val hdpPath = new Path(path) + val fs = getFileSystemForPath(hdpPath, conf) + fs.isFile(hdpPath) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala new file mode 100644 index 0000000000000..34287c3e00908 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import java.io.{ObjectInputStream, ObjectOutputStream} + +import scala.reflect.ClassTag + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._ +import org.apache.spark.util.collection.OpenHashMap + +/** Internal interface for defining the map that keeps track of sessions. */ +private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable { + + /** Get the state for a key if it exists */ + def get(key: K): Option[S] + + /** Get all the keys and states whose updated time is older than the given threshold time */ + def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] + + /** Get all the keys and states in this map. */ + def getAll(): Iterator[(K, S, Long)] + + /** Add or update state */ + def put(key: K, state: S, updatedTime: Long): Unit + + /** Remove a key */ + def remove(key: K): Unit + + /** + * Shallow copy `this` map to create a new state map. + * Updates to the new map should not mutate `this` map. + */ + def copy(): StateMap[K, S] + + def toDebugString(): String = toString() +} + +/** Companion object for [[StateMap]], with utility methods */ +private[streaming] object StateMap { + def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S] + + def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { + val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", + DELTA_CHAIN_LENGTH_THRESHOLD) + new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold) + } +} + +/** Implementation of StateMap interface representing an empty map */ +private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] { + override def put(key: K, session: S, updateTime: Long): Unit = { + throw new NotImplementedError("put() should not be called on an EmptyStateMap") + } + override def get(key: K): Option[S] = None + override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty + override def getAll(): Iterator[(K, S, Long)] = Iterator.empty + override def copy(): StateMap[K, S] = this + override def remove(key: K): Unit = { } + override def toDebugString(): String = "" +} + +/** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */ +private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( + @transient @volatile var parentStateMap: StateMap[K, S], + initialCapacity: Int = 64, + deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD + ) extends StateMap[K, S] { self => + + def this(initialCapacity: Int, deltaChainThreshold: Int) = this( + new EmptyStateMap[K, S], + initialCapacity = initialCapacity, + deltaChainThreshold = deltaChainThreshold) + + def this(deltaChainThreshold: Int) = this( + initialCapacity = 64, deltaChainThreshold = deltaChainThreshold) + + def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) + + @transient @volatile private var deltaMap = + new OpenHashMap[K, StateInfo[S]](initialCapacity) + + /** Get the session data if it exists */ + override def get(key: K): Option[S] = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + if (!stateInfo.deleted) { + Some(stateInfo.data) + } else { + None + } + } else { + parentStateMap.get(key) + } + } + + /** Get all the keys and states whose updated time is older than the give threshold time */ + override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = { + val oldStates = parentStateMap.getByTime(threshUpdatedTime).filter { case (key, value, _) => + !deltaMap.contains(key) + } + + val updatedStates = deltaMap.iterator.filter { case (_, stateInfo) => + !stateInfo.deleted && stateInfo.updateTime < threshUpdatedTime + }.map { case (key, stateInfo) => + (key, stateInfo.data, stateInfo.updateTime) + } + oldStates ++ updatedStates + } + + /** Get all the keys and states in this map. */ + override def getAll(): Iterator[(K, S, Long)] = { + + val oldStates = parentStateMap.getAll().filter { case (key, _, _) => + !deltaMap.contains(key) + } + + val updatedStates = deltaMap.iterator.filter { ! _._2.deleted }.map { case (key, stateInfo) => + (key, stateInfo.data, stateInfo.updateTime) + } + oldStates ++ updatedStates + } + + /** Add or update state */ + override def put(key: K, state: S, updateTime: Long): Unit = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + stateInfo.update(state, updateTime) + } else { + deltaMap.update(key, new StateInfo(state, updateTime)) + } + } + + /** Remove a state */ + override def remove(key: K): Unit = { + val stateInfo = deltaMap(key) + if (stateInfo != null) { + stateInfo.markDeleted() + } else { + val newInfo = new StateInfo[S](deleted = true) + deltaMap.update(key, newInfo) + } + } + + /** + * Shallow copy the map to create a new session store. Updates to the new map + * should not mutate `this` map. + */ + override def copy(): StateMap[K, S] = { + new OpenHashMapBasedStateMap[K, S](this, deltaChainThreshold = deltaChainThreshold) + } + + /** Whether the delta chain lenght is long enough that it should be compacted */ + def shouldCompact: Boolean = { + deltaChainLength >= deltaChainThreshold + } + + /** Length of the delta chains of this map */ + def deltaChainLength: Int = parentStateMap match { + case map: OpenHashMapBasedStateMap[_, _] => map.deltaChainLength + 1 + case _ => 0 + } + + /** + * Approximate number of keys in the map. This is an overestimation that is mainly used to + * reserve capacity in a new map at delta compaction time. + */ + def approxSize: Int = deltaMap.size + { + parentStateMap match { + case s: OpenHashMapBasedStateMap[_, _] => s.approxSize + case _ => 0 + } + } + + /** Get all the data of this map as string formatted as a tree based on the delta depth */ + override def toDebugString(): String = { + val tabs = if (deltaChainLength > 0) { + (" " * (deltaChainLength - 1)) + "+--- " + } else "" + parentStateMap.toDebugString() + "\n" + deltaMap.iterator.mkString(tabs, "\n" + tabs, "") + } + + override def toString(): String = { + s"[${System.identityHashCode(this)}, ${System.identityHashCode(parentStateMap)}]" + } + + /** + * Serialize the map data. Besides serialization, this method actually compact the deltas + * (if needed) in a single pass over all the data in the map. + */ + + private def writeObject(outputStream: ObjectOutputStream): Unit = { + // Write all the non-transient fields, especially class tags, etc. + outputStream.defaultWriteObject() + + // Write the data in the delta of this state map + outputStream.writeInt(deltaMap.size) + val deltaMapIterator = deltaMap.iterator + var deltaMapCount = 0 + while (deltaMapIterator.hasNext) { + deltaMapCount += 1 + val (key, stateInfo) = deltaMapIterator.next() + outputStream.writeObject(key) + outputStream.writeObject(stateInfo) + } + assert(deltaMapCount == deltaMap.size) + + // Write the data in the parent state map while copying the data into a new parent map for + // compaction (if needed) + val doCompaction = shouldCompact + val newParentSessionStore = if (doCompaction) { + val initCapacity = if (approxSize > 0) approxSize else 64 + new OpenHashMapBasedStateMap[K, S](initialCapacity = initCapacity, deltaChainThreshold) + } else { null } + + val iterOfActiveSessions = parentStateMap.getAll() + + var parentSessionCount = 0 + + // First write the approximate size of the data to be written, so that readObject can + // allocate appropriately sized OpenHashMap. + outputStream.writeInt(approxSize) + + while(iterOfActiveSessions.hasNext) { + parentSessionCount += 1 + + val (key, state, updateTime) = iterOfActiveSessions.next() + outputStream.writeObject(key) + outputStream.writeObject(state) + outputStream.writeLong(updateTime) + + if (doCompaction) { + newParentSessionStore.deltaMap.update( + key, StateInfo(state, updateTime, deleted = false)) + } + } + + // Write the final limit marking object with the correct count of records written. + val limiterObj = new LimitMarker(parentSessionCount) + outputStream.writeObject(limiterObj) + if (doCompaction) { + parentStateMap = newParentSessionStore + } + } + + /** Deserialize the map data. */ + private def readObject(inputStream: ObjectInputStream): Unit = { + + // Read the non-transient fields, especially class tags, etc. + inputStream.defaultReadObject() + + // Read the data of the delta + val deltaMapSize = inputStream.readInt() + deltaMap = if (deltaMapSize != 0) { + new OpenHashMap[K, StateInfo[S]](deltaMapSize) + } else { + new OpenHashMap[K, StateInfo[S]](initialCapacity) + } + var deltaMapCount = 0 + while (deltaMapCount < deltaMapSize) { + val key = inputStream.readObject().asInstanceOf[K] + val sessionInfo = inputStream.readObject().asInstanceOf[StateInfo[S]] + deltaMap.update(key, sessionInfo) + deltaMapCount += 1 + } + + + // Read the data of the parent map. Keep reading records, until the limiter is reached + // First read the approximate number of records to expect and allocate properly size + // OpenHashMap + val parentSessionStoreSizeHint = inputStream.readInt() + val newParentSessionStore = new OpenHashMapBasedStateMap[K, S]( + initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold) + + // Read the records until the limit marking object has been reached + var parentSessionLoopDone = false + while(!parentSessionLoopDone) { + val obj = inputStream.readObject() + if (obj.isInstanceOf[LimitMarker]) { + parentSessionLoopDone = true + val expectedCount = obj.asInstanceOf[LimitMarker].num + assert(expectedCount == newParentSessionStore.deltaMap.size) + } else { + val key = obj.asInstanceOf[K] + val state = inputStream.readObject().asInstanceOf[S] + val updateTime = inputStream.readLong() + newParentSessionStore.deltaMap.update( + key, StateInfo(state, updateTime, deleted = false)) + } + } + parentStateMap = newParentSessionStore + } +} + +/** + * Companion object of [[OpenHashMapBasedStateMap]] having associated helper + * classes and methods + */ +private[streaming] object OpenHashMapBasedStateMap { + + /** Internal class to represent the state information */ + case class StateInfo[S]( + var data: S = null.asInstanceOf[S], + var updateTime: Long = -1, + var deleted: Boolean = false) { + + def markDeleted(): Unit = { + deleted = true + } + + def update(newData: S, newUpdateTime: Long): Unit = { + data = newData + updateTime = newUpdateTime + deleted = false + } + } + + /** + * Internal class to represent a marker the demarkate the the end of all state data in the + * serialized bytes. + */ + class LimitMarker(val num: Int) extends Serializable + + val DELTA_CHAIN_LENGTH_THRESHOLD = 20 +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala index 0ea970e61b694..7f9e2c9734970 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala @@ -38,6 +38,8 @@ private[streaming] object WriteAheadLogUtils extends Logging { val DRIVER_WAL_ROLLING_INTERVAL_CONF_KEY = "spark.streaming.driver.writeAheadLog.rollingIntervalSecs" val DRIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.driver.writeAheadLog.maxFailures" + val DRIVER_WAL_BATCHING_CONF_KEY = "spark.streaming.driver.writeAheadLog.allowBatching" + val DRIVER_WAL_BATCHING_TIMEOUT_CONF_KEY = "spark.streaming.driver.writeAheadLog.batchingTimeout" val DRIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY = "spark.streaming.driver.writeAheadLog.closeFileAfterWrite" @@ -64,6 +66,18 @@ private[streaming] object WriteAheadLogUtils extends Logging { } } + def isBatchingEnabled(conf: SparkConf, isDriver: Boolean): Boolean = { + isDriver && conf.getBoolean(DRIVER_WAL_BATCHING_CONF_KEY, defaultValue = true) + } + + /** + * How long we will wait for the wrappedLog in the BatchedWriteAheadLog to write the records + * before we fail the write attempt to unblock receivers. + */ + def getBatchingTimeout(conf: SparkConf): Long = { + conf.getLong(DRIVER_WAL_BATCHING_TIMEOUT_CONF_KEY, defaultValue = 5000) + } + def shouldCloseFileAfterWrite(conf: SparkConf, isDriver: Boolean): Boolean = { if (isDriver) { conf.getBoolean(DRIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY, defaultValue = false) @@ -115,7 +129,7 @@ private[streaming] object WriteAheadLogUtils extends Logging { } else { sparkConf.getOption(RECEIVER_WAL_CLASS_CONF_KEY) } - classNameOption.map { className => + val wal = classNameOption.map { className => try { instantiateClass( Utils.classForName(className).asInstanceOf[Class[_ <: WriteAheadLog]], sparkConf) @@ -128,6 +142,11 @@ private[streaming] object WriteAheadLogUtils extends Logging { getRollingIntervalSecs(sparkConf, isDriver), getMaxFailures(sparkConf, isDriver), shouldCloseFileAfterWrite(sparkConf, isDriver)) } + if (isBatchingEnabled(sparkConf, isDriver)) { + new BatchedWriteAheadLog(wal, sparkConf) + } else { + wal + } } /** Instantiate the class, either using single arg constructor or zero arg constructor */ diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java new file mode 100644 index 0000000000000..67b2a0703e02b --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.*; + +public class JavaStreamingListenerAPISuite extends JavaStreamingListener { + + @Override + public void onReceiverStarted(JavaStreamingListenerReceiverStarted receiverStarted) { + JavaReceiverInfo receiverInfo = receiverStarted.receiverInfo(); + receiverInfo.streamId(); + receiverInfo.name(); + receiverInfo.active(); + receiverInfo.location(); + receiverInfo.executorId(); + receiverInfo.lastErrorMessage(); + receiverInfo.lastError(); + receiverInfo.lastErrorTime(); + } + + @Override + public void onReceiverError(JavaStreamingListenerReceiverError receiverError) { + JavaReceiverInfo receiverInfo = receiverError.receiverInfo(); + receiverInfo.streamId(); + receiverInfo.name(); + receiverInfo.active(); + receiverInfo.location(); + receiverInfo.executorId(); + receiverInfo.lastErrorMessage(); + receiverInfo.lastError(); + receiverInfo.lastErrorTime(); + } + + @Override + public void onReceiverStopped(JavaStreamingListenerReceiverStopped receiverStopped) { + JavaReceiverInfo receiverInfo = receiverStopped.receiverInfo(); + receiverInfo.streamId(); + receiverInfo.name(); + receiverInfo.active(); + receiverInfo.location(); + receiverInfo.executorId(); + receiverInfo.lastErrorMessage(); + receiverInfo.lastError(); + receiverInfo.lastErrorTime(); + } + + @Override + public void onBatchSubmitted(JavaStreamingListenerBatchSubmitted batchSubmitted) { + super.onBatchSubmitted(batchSubmitted); + } + + @Override + public void onBatchStarted(JavaStreamingListenerBatchStarted batchStarted) { + super.onBatchStarted(batchStarted); + } + + @Override + public void onBatchCompleted(JavaStreamingListenerBatchCompleted batchCompleted) { + super.onBatchCompleted(batchCompleted); + } + + @Override + public void onOutputOperationStarted(JavaStreamingListenerOutputOperationStarted outputOperationStarted) { + super.onOutputOperationStarted(outputOperationStarted); + } + + @Override + public void onOutputOperationCompleted(JavaStreamingListenerOutputOperationCompleted outputOperationCompleted) { + super.onOutputOperationCompleted(outputOperationCompleted); + } +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java new file mode 100644 index 0000000000000..eac4cdd14a683 --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import scala.Tuple2; + +import com.google.common.base.Optional; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.util.ManualClock; +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.Function4; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaTrackStateDStream; + +public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implements Serializable { + + /** + * This test is only for testing the APIs. It's not necessary to run it. + */ + public void testAPI() { + JavaPairRDD initialRDD = null; + JavaPairDStream wordsDstream = null; + + final Function4, State, Optional> + trackStateFunc = + new Function4, State, Optional>() { + + @Override + public Optional call( + Time time, String word, Optional one, State state) { + // Use all State's methods here + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return Optional.of(2.0); + } + }; + + JavaTrackStateDStream stateDstream = + wordsDstream.trackStateByKey( + StateSpec.function(trackStateFunc) + .initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); + + JavaPairDStream emittedRecords = stateDstream.stateSnapshots(); + + final Function2, State, Double> trackStateFunc2 = + new Function2, State, Double>() { + + @Override + public Double call(Optional one, State state) { + // Use all State's methods here + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return 2.0; + } + }; + + JavaTrackStateDStream stateDstream2 = + wordsDstream.trackStateByKey( + StateSpec. function(trackStateFunc2) + .initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); + + JavaPairDStream emittedRecords2 = stateDstream2.stateSnapshots(); + } + + @Test + public void testBasicFunction() { + List> inputData = Arrays.asList( + Collections.emptyList(), + Arrays.asList("a"), + Arrays.asList("a", "b"), + Arrays.asList("a", "b", "c"), + Arrays.asList("a", "b"), + Arrays.asList("a"), + Collections.emptyList() + ); + + List> outputData = Arrays.asList( + Collections.emptySet(), + Sets.newHashSet(1), + Sets.newHashSet(2, 1), + Sets.newHashSet(3, 2, 1), + Sets.newHashSet(4, 3), + Sets.newHashSet(5), + Collections.emptySet() + ); + + List>> stateData = Arrays.asList( + Collections.>emptySet(), + Sets.newHashSet(new Tuple2("a", 1)), + Sets.newHashSet(new Tuple2("a", 2), new Tuple2("b", 1)), + Sets.newHashSet( + new Tuple2("a", 3), + new Tuple2("b", 2), + new Tuple2("c", 1)), + Sets.newHashSet( + new Tuple2("a", 4), + new Tuple2("b", 3), + new Tuple2("c", 1)), + Sets.newHashSet( + new Tuple2("a", 5), + new Tuple2("b", 3), + new Tuple2("c", 1)), + Sets.newHashSet( + new Tuple2("a", 5), + new Tuple2("b", 3), + new Tuple2("c", 1)) + ); + + Function2, State, Integer> trackStateFunc = + new Function2, State, Integer>() { + + @Override + public Integer call(Optional value, State state) throws Exception { + int sum = value.or(0) + (state.exists() ? state.get() : 0); + state.update(sum); + return sum; + } + }; + testOperation( + inputData, + StateSpec.function(trackStateFunc), + outputData, + stateData); + } + + private void testOperation( + List> input, + StateSpec trackStateSpec, + List> expectedOutputs, + List>> expectedStateSnapshots) { + int numBatches = expectedOutputs.size(); + JavaDStream inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2); + JavaTrackStateDStream trackeStateStream = + JavaPairDStream.fromJavaDStream(inputStream.map(new Function>() { + @Override + public Tuple2 call(K x) throws Exception { + return new Tuple2(x, 1); + } + })).trackStateByKey(trackStateSpec); + + final List> collectedOutputs = + Collections.synchronizedList(Lists.>newArrayList()); + trackeStateStream.foreachRDD(new Function, Void>() { + @Override + public Void call(JavaRDD rdd) throws Exception { + collectedOutputs.add(Sets.newHashSet(rdd.collect())); + return null; + } + }); + final List>> collectedStateSnapshots = + Collections.synchronizedList(Lists.>>newArrayList()); + trackeStateStream.stateSnapshots().foreachRDD(new Function, Void>() { + @Override + public Void call(JavaPairRDD rdd) throws Exception { + collectedStateSnapshots.add(Sets.newHashSet(rdd.collect())); + return null; + } + }); + BatchCounter batchCounter = new BatchCounter(ssc.ssc()); + ssc.start(); + ((ManualClock) ssc.ssc().scheduler().clock()) + .advance(ssc.ssc().progressListener().batchDuration() * numBatches + 1); + batchCounter.waitUntilBatchesCompleted(numBatches, 10000); + + Assert.assertEquals(expectedOutputs, collectedOutputs); + Assert.assertEquals(expectedStateSnapshots, collectedStateSnapshots); + } +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java index 175b8a496b4e5..09b5f8ed03279 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java @@ -108,6 +108,7 @@ public void close() { public void testCustomWAL() { SparkConf conf = new SparkConf(); conf.set("spark.streaming.driver.writeAheadLog.class", JavaWriteAheadLogSuite.class.getName()); + conf.set("spark.streaming.driver.writeAheadLog.allowBatching", "false"); WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null); String data1 = "data1"; diff --git a/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala b/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala new file mode 100644 index 0000000000000..0295e059f7bc2 --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.streaming.Time +import org.apache.spark.streaming.scheduler._ + +class JavaStreamingListenerWrapperSuite extends SparkFunSuite { + + test("basic") { + val listener = new TestJavaStreamingListener() + val listenerWrapper = new JavaStreamingListenerWrapper(listener) + + val receiverStarted = StreamingListenerReceiverStarted(ReceiverInfo( + streamId = 2, + name = "test", + active = true, + location = "localhost", + executorId = "1" + )) + listenerWrapper.onReceiverStarted(receiverStarted) + assertReceiverInfo(listener.receiverStarted.receiverInfo, receiverStarted.receiverInfo) + + val receiverStopped = StreamingListenerReceiverStopped(ReceiverInfo( + streamId = 2, + name = "test", + active = false, + location = "localhost", + executorId = "1" + )) + listenerWrapper.onReceiverStopped(receiverStopped) + assertReceiverInfo(listener.receiverStopped.receiverInfo, receiverStopped.receiverInfo) + + val receiverError = StreamingListenerReceiverError(ReceiverInfo( + streamId = 2, + name = "test", + active = false, + location = "localhost", + executorId = "1", + lastErrorMessage = "failed", + lastError = "failed", + lastErrorTime = System.currentTimeMillis() + )) + listenerWrapper.onReceiverError(receiverError) + assertReceiverInfo(listener.receiverError.receiverInfo, receiverError.receiverInfo) + + val batchSubmitted = StreamingListenerBatchSubmitted(BatchInfo( + batchTime = Time(1000L), + streamIdToInputInfo = Map( + 0 -> StreamInputInfo( + inputStreamId = 0, + numRecords = 1000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver1")), + 1 -> StreamInputInfo( + inputStreamId = 1, + numRecords = 2000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver2"))), + submissionTime = 1001L, + None, + None, + outputOperationInfos = Map( + 0 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = None, + endTime = None, + failureReason = None), + 1 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 1, + name = "op2", + description = "operation2", + startTime = None, + endTime = None, + failureReason = None)) + )) + listenerWrapper.onBatchSubmitted(batchSubmitted) + assertBatchInfo(listener.batchSubmitted.batchInfo, batchSubmitted.batchInfo) + + val batchStarted = StreamingListenerBatchStarted(BatchInfo( + batchTime = Time(1000L), + streamIdToInputInfo = Map( + 0 -> StreamInputInfo( + inputStreamId = 0, + numRecords = 1000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver1")), + 1 -> StreamInputInfo( + inputStreamId = 1, + numRecords = 2000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver2"))), + submissionTime = 1001L, + Some(1002L), + None, + outputOperationInfos = Map( + 0 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = None, + failureReason = None), + 1 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 1, + name = "op2", + description = "operation2", + startTime = Some(1005L), + endTime = None, + failureReason = None)) + )) + listenerWrapper.onBatchStarted(batchStarted) + assertBatchInfo(listener.batchStarted.batchInfo, batchStarted.batchInfo) + + val batchCompleted = StreamingListenerBatchCompleted(BatchInfo( + batchTime = Time(1000L), + streamIdToInputInfo = Map( + 0 -> StreamInputInfo( + inputStreamId = 0, + numRecords = 1000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver1")), + 1 -> StreamInputInfo( + inputStreamId = 1, + numRecords = 2000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver2"))), + submissionTime = 1001L, + Some(1002L), + Some(1010L), + outputOperationInfos = Map( + 0 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = Some(1004L), + failureReason = None), + 1 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 1, + name = "op2", + description = "operation2", + startTime = Some(1005L), + endTime = Some(1010L), + failureReason = None)) + )) + listenerWrapper.onBatchCompleted(batchCompleted) + assertBatchInfo(listener.batchCompleted.batchInfo, batchCompleted.batchInfo) + + val outputOperationStarted = StreamingListenerOutputOperationStarted(OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = None, + failureReason = None + )) + listenerWrapper.onOutputOperationStarted(outputOperationStarted) + assertOutputOperationInfo(listener.outputOperationStarted.outputOperationInfo, + outputOperationStarted.outputOperationInfo) + + val outputOperationCompleted = StreamingListenerOutputOperationCompleted(OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = Some(1004L), + failureReason = None + )) + listenerWrapper.onOutputOperationCompleted(outputOperationCompleted) + assertOutputOperationInfo(listener.outputOperationCompleted.outputOperationInfo, + outputOperationCompleted.outputOperationInfo) + } + + private def assertReceiverInfo( + javaReceiverInfo: JavaReceiverInfo, receiverInfo: ReceiverInfo): Unit = { + assert(javaReceiverInfo.streamId === receiverInfo.streamId) + assert(javaReceiverInfo.name === receiverInfo.name) + assert(javaReceiverInfo.active === receiverInfo.active) + assert(javaReceiverInfo.location === receiverInfo.location) + assert(javaReceiverInfo.executorId === receiverInfo.executorId) + assert(javaReceiverInfo.lastErrorMessage === receiverInfo.lastErrorMessage) + assert(javaReceiverInfo.lastError === receiverInfo.lastError) + assert(javaReceiverInfo.lastErrorTime === receiverInfo.lastErrorTime) + } + + private def assertBatchInfo(javaBatchInfo: JavaBatchInfo, batchInfo: BatchInfo): Unit = { + assert(javaBatchInfo.batchTime === batchInfo.batchTime) + assert(javaBatchInfo.streamIdToInputInfo.size === batchInfo.streamIdToInputInfo.size) + batchInfo.streamIdToInputInfo.foreach { case (streamId, streamInputInfo) => + assertStreamingInfo(javaBatchInfo.streamIdToInputInfo.get(streamId), streamInputInfo) + } + assert(javaBatchInfo.submissionTime === batchInfo.submissionTime) + assert(javaBatchInfo.processingStartTime === batchInfo.processingStartTime.getOrElse(-1)) + assert(javaBatchInfo.processingEndTime === batchInfo.processingEndTime.getOrElse(-1)) + assert(javaBatchInfo.schedulingDelay === batchInfo.schedulingDelay.getOrElse(-1)) + assert(javaBatchInfo.processingDelay === batchInfo.processingDelay.getOrElse(-1)) + assert(javaBatchInfo.totalDelay === batchInfo.totalDelay.getOrElse(-1)) + assert(javaBatchInfo.numRecords === batchInfo.numRecords) + assert(javaBatchInfo.outputOperationInfos.size === batchInfo.outputOperationInfos.size) + batchInfo.outputOperationInfos.foreach { case (outputOperationId, outputOperationInfo) => + assertOutputOperationInfo( + javaBatchInfo.outputOperationInfos.get(outputOperationId), outputOperationInfo) + } + } + + private def assertStreamingInfo( + javaStreamInputInfo: JavaStreamInputInfo, streamInputInfo: StreamInputInfo): Unit = { + assert(javaStreamInputInfo.inputStreamId === streamInputInfo.inputStreamId) + assert(javaStreamInputInfo.numRecords === streamInputInfo.numRecords) + assert(javaStreamInputInfo.metadata === streamInputInfo.metadata.asJava) + assert(javaStreamInputInfo.metadataDescription === streamInputInfo.metadataDescription.orNull) + } + + private def assertOutputOperationInfo( + javaOutputOperationInfo: JavaOutputOperationInfo, + outputOperationInfo: OutputOperationInfo): Unit = { + assert(javaOutputOperationInfo.batchTime === outputOperationInfo.batchTime) + assert(javaOutputOperationInfo.id === outputOperationInfo.id) + assert(javaOutputOperationInfo.name === outputOperationInfo.name) + assert(javaOutputOperationInfo.description === outputOperationInfo.description) + assert(javaOutputOperationInfo.startTime === outputOperationInfo.startTime.getOrElse(-1)) + assert(javaOutputOperationInfo.endTime === outputOperationInfo.endTime.getOrElse(-1)) + assert(javaOutputOperationInfo.failureReason === outputOperationInfo.failureReason.orNull) + } +} + +class TestJavaStreamingListener extends JavaStreamingListener { + + var receiverStarted: JavaStreamingListenerReceiverStarted = null + var receiverError: JavaStreamingListenerReceiverError = null + var receiverStopped: JavaStreamingListenerReceiverStopped = null + var batchSubmitted: JavaStreamingListenerBatchSubmitted = null + var batchStarted: JavaStreamingListenerBatchStarted = null + var batchCompleted: JavaStreamingListenerBatchCompleted = null + var outputOperationStarted: JavaStreamingListenerOutputOperationStarted = null + var outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted = null + + override def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { + this.receiverStarted = receiverStarted + } + + override def onReceiverError(receiverError: JavaStreamingListenerReceiverError): Unit = { + this.receiverError = receiverError + } + + override def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped): Unit = { + this.receiverStopped = receiverStopped + } + + override def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted): Unit = { + this.batchSubmitted = batchSubmitted + } + + override def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted): Unit = { + this.batchStarted = batchStarted + } + + override def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted): Unit = { + this.batchCompleted = batchCompleted + } + + override def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted): Unit = { + this.outputOperationStarted = outputOperationStarted + } + + override def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted): Unit = { + this.outputOperationCompleted = outputOperationCompleted + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala index 8844c9d74b933..bc223e648a417 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.streaming +import scala.collection.mutable.ArrayBuffer + import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.rdd.RDDOperationScope +import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.ui.UIUtils +import org.apache.spark.util.ManualClock +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} /** * Tests whether scope information is passed from DStream operations to RDDs correctly. @@ -32,7 +35,9 @@ class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAnd private val batchDuration: Duration = Seconds(1) override def beforeAll(): Unit = { - ssc = new StreamingContext(new SparkContext("local", "test"), batchDuration) + val conf = new SparkConf().setMaster("local").setAppName("test") + conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + ssc = new StreamingContext(new SparkContext(conf), batchDuration) } override def afterAll(): Unit = { @@ -103,6 +108,8 @@ class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAnd test("scoping nested operations") { val inputStream = new DummyInputDStream(ssc) + // countByKeyAndWindow internally uses reduceByKeyAndWindow, but only countByKeyAndWindow + // should appear in scope val countStream = inputStream.countByWindow(Seconds(10), Seconds(1)) countStream.initialize(Time(0)) @@ -137,6 +144,57 @@ class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAnd testStream(countStream) } + test("transform should allow RDD operations to be captured in scopes") { + val inputStream = new DummyInputDStream(ssc) + val transformedStream = inputStream.transform { _.map { _ -> 1}.reduceByKey(_ + _) } + transformedStream.initialize(Time(0)) + + val transformScopeBase = transformedStream.baseScope.map(RDDOperationScope.fromJson) + val transformScope1 = transformedStream.getOrCompute(Time(1000)).get.scope + val transformScope2 = transformedStream.getOrCompute(Time(2000)).get.scope + val transformScope3 = transformedStream.getOrCompute(Time(3000)).get.scope + + // Assert that all children RDDs inherit the DStream operation name correctly + assertDefined(transformScopeBase, transformScope1, transformScope2, transformScope3) + assert(transformScopeBase.get.name === "transform") + assertNestedScopeCorrect(transformScope1.get, 1000) + assertNestedScopeCorrect(transformScope2.get, 2000) + assertNestedScopeCorrect(transformScope3.get, 3000) + + def assertNestedScopeCorrect(rddScope: RDDOperationScope, batchTime: Long): Unit = { + assert(rddScope.name === "reduceByKey") + assert(rddScope.parent.isDefined) + assertScopeCorrect(transformScopeBase.get, rddScope.parent.get, batchTime) + } + } + + test("foreachRDD should allow RDD operations to be captured in scope") { + val inputStream = new DummyInputDStream(ssc) + val generatedRDDs = new ArrayBuffer[RDD[(Int, Int)]] + inputStream.foreachRDD { rdd => + generatedRDDs += rdd.map { _ -> 1}.reduceByKey(_ + _) + } + val batchCounter = new BatchCounter(ssc) + ssc.start() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(3000) + batchCounter.waitUntilBatchesCompleted(3, 10000) + assert(generatedRDDs.size === 3) + + val foreachBaseScope = + ssc.graph.getOutputStreams().head.baseScope.map(RDDOperationScope.fromJson) + assertDefined(foreachBaseScope) + assert(foreachBaseScope.get.name === "foreachRDD") + + val rddScopes = generatedRDDs.map { _.scope } + assertDefined(rddScopes: _*) + rddScopes.zipWithIndex.foreach { case (rddScope, idx) => + assert(rddScope.get.name === "reduceByKey") + assert(rddScope.get.parent.isDefined) + assertScopeCorrect(foreachBaseScope.get, rddScope.get.parent.get, (idx + 1) * 1000) + } + } + /** Assert that the RDD operation scope properties are not set in our SparkContext. */ private def assertPropertiesNotSet(): Unit = { assert(ssc != null) @@ -149,19 +207,12 @@ class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAnd baseScope: RDDOperationScope, rddScope: RDDOperationScope, batchTime: Long): Unit = { - assertScopeCorrect(baseScope.id, baseScope.name, rddScope, batchTime) - } - - /** Assert that the given RDD scope inherits the base name and ID correctly. */ - private def assertScopeCorrect( - baseScopeId: String, - baseScopeName: String, - rddScope: RDDOperationScope, - batchTime: Long): Unit = { + val (baseScopeId, baseScopeName) = (baseScope.id, baseScope.name) val formattedBatchTime = UIUtils.formatBatchTime( batchTime, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) assert(rddScope.id === s"${baseScopeId}_$batchTime") assert(rddScope.name.replaceAll("\\n", " ") === s"$baseScopeName @ $formattedBatchTime") + assert(rddScope.parent.isEmpty) // There should not be any higher scope } /** Assert that all the specified options are defined. */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index f793a12843b2f..081f5a1c93e6e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming import java.io.File +import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ @@ -32,7 +33,7 @@ import org.apache.spark.{Logging, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.util.{WriteAheadLogUtils, FileBasedWriteAheadLogReader} +import org.apache.spark.streaming.util._ import org.apache.spark.streaming.util.WriteAheadLogSuite._ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} @@ -207,6 +208,75 @@ class ReceivedBlockTrackerSuite tracker1.isWriteAheadLogEnabled should be (false) } + test("parallel file deletion in FileBasedWriteAheadLog is robust to deletion error") { + conf.set("spark.streaming.driver.writeAheadLog.rollingIntervalSecs", "1") + require(WriteAheadLogUtils.getRollingIntervalSecs(conf, isDriver = true) === 1) + + val addBlocks = generateBlockInfos() + val batch1 = addBlocks.slice(0, 1) + val batch2 = addBlocks.slice(1, 3) + val batch3 = addBlocks.slice(3, addBlocks.length) + + assert(getWriteAheadLogFiles().length === 0) + + // list of timestamps for files + val t = Seq.tabulate(5)(i => i * 1000) + + writeEventsManually(getLogFileName(t(0)), Seq(createBatchCleanup(t(0)))) + assert(getWriteAheadLogFiles().length === 1) + + // The goal is to create several log files which should have been cleaned up. + // If we face any issue during recovery, because these old files exist, then we need to make + // deletion more robust rather than a parallelized operation where we fire and forget + val batch1Allocation = createBatchAllocation(t(1), batch1) + writeEventsManually(getLogFileName(t(1)), batch1.map(BlockAdditionEvent) :+ batch1Allocation) + + writeEventsManually(getLogFileName(t(2)), Seq(createBatchCleanup(t(1)))) + + val batch2Allocation = createBatchAllocation(t(3), batch2) + writeEventsManually(getLogFileName(t(3)), batch2.map(BlockAdditionEvent) :+ batch2Allocation) + + writeEventsManually(getLogFileName(t(4)), batch3.map(BlockAdditionEvent)) + + // We should have 5 different log files as we called `writeEventsManually` with 5 different + // timestamps + assert(getWriteAheadLogFiles().length === 5) + + // Create the tracker to recover from the log files. We're going to ask the tracker to clean + // things up, and then we're going to rewrite that data, and recover using a different tracker. + // They should have identical data no matter what + val tracker = createTracker(recoverFromWriteAheadLog = true, clock = new ManualClock(t(4))) + + def compareTrackers(base: ReceivedBlockTracker, subject: ReceivedBlockTracker): Unit = { + subject.getBlocksOfBatchAndStream(t(3), streamId) should be( + base.getBlocksOfBatchAndStream(t(3), streamId)) + subject.getBlocksOfBatchAndStream(t(1), streamId) should be( + base.getBlocksOfBatchAndStream(t(1), streamId)) + subject.getBlocksOfBatchAndStream(t(0), streamId) should be(Nil) + } + + // ask the tracker to clean up some old files + tracker.cleanupOldBatches(t(3), waitForCompletion = true) + assert(getWriteAheadLogFiles().length === 3) + + val tracker2 = createTracker(recoverFromWriteAheadLog = true, clock = new ManualClock(t(4))) + compareTrackers(tracker, tracker2) + + // rewrite first file + writeEventsManually(getLogFileName(t(0)), Seq(createBatchCleanup(t(0)))) + assert(getWriteAheadLogFiles().length === 4) + // make sure trackers are consistent + val tracker3 = createTracker(recoverFromWriteAheadLog = true, clock = new ManualClock(t(4))) + compareTrackers(tracker, tracker3) + + // rewrite second file + writeEventsManually(getLogFileName(t(1)), batch1.map(BlockAdditionEvent) :+ batch1Allocation) + assert(getWriteAheadLogFiles().length === 5) + // make sure trackers are consistent + val tracker4 = createTracker(recoverFromWriteAheadLog = true, clock = new ManualClock(t(4))) + compareTrackers(tracker, tracker4) + } + /** * Create tracker object with the optional provided clock. Use fake clock if you * want to control time by manually incrementing it to test log clean. @@ -228,11 +298,30 @@ class ReceivedBlockTrackerSuite BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)), Some(0L)))) } + /** + * Write received block tracker events to a file manually. + */ + def writeEventsManually(filePath: String, events: Seq[ReceivedBlockTrackerLogEvent]): Unit = { + val writer = HdfsUtils.getOutputStream(filePath, hadoopConf) + events.foreach { event => + val bytes = Utils.serialize(event) + writer.writeInt(bytes.size) + writer.write(bytes) + } + writer.close() + } + /** Get all the data written in the given write ahead log file. */ def getWrittenLogData(logFile: String): Seq[ReceivedBlockTrackerLogEvent] = { getWrittenLogData(Seq(logFile)) } + /** Get the log file name for the given log start time. */ + def getLogFileName(time: Long, rollingIntervalSecs: Int = 1): String = { + checkpointDirectory.toString + File.separator + "receivedBlockMetadata" + + File.separator + s"log-$time-${time + rollingIntervalSecs * 1000}" + } + /** * Get all the data written in the given write ahead log files. By default, it will read all * files in the test log directory. @@ -241,8 +330,13 @@ class ReceivedBlockTrackerSuite : Seq[ReceivedBlockTrackerLogEvent] = { logFiles.flatMap { file => new FileBasedWriteAheadLogReader(file, hadoopConf).toSeq - }.map { byteBuffer => - Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) + }.flatMap { byteBuffer => + val validBuffer = if (WriteAheadLogUtils.isBatchingEnabled(conf, isDriver = true)) { + Utils.deserialize[Array[Array[Byte]]](byteBuffer.array()).map(ByteBuffer.wrap) + } else { + Array(byteBuffer) + } + validBuffer.map(b => Utils.deserialize[ReceivedBlockTrackerLogEvent](b.array())) }.toList } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala new file mode 100644 index 0000000000000..48d3b41b66cbf --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import scala.collection.{immutable, mutable, Map} +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap} +import org.apache.spark.util.Utils + +class StateMapSuite extends SparkFunSuite { + + test("EmptyStateMap") { + val map = new EmptyStateMap[Int, Int] + intercept[scala.NotImplementedError] { + map.put(1, 1, 1) + } + assert(map.get(1) === None) + assert(map.getByTime(10000).isEmpty) + assert(map.getAll().isEmpty) + map.remove(1) // no exception + assert(map.copy().eq(map)) + } + + test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove") { + val map = new OpenHashMapBasedStateMap[Int, Int]() + + map.put(1, 100, 10) + assert(map.get(1) === Some(100)) + assert(map.get(2) === None) + assert(map.getByTime(11).toSet === Set((1, 100, 10))) + assert(map.getByTime(10).toSet === Set.empty) + assert(map.getByTime(9).toSet === Set.empty) + assert(map.getAll().toSet === Set((1, 100, 10))) + + map.put(2, 200, 20) + assert(map.getByTime(21).toSet === Set((1, 100, 10), (2, 200, 20))) + assert(map.getByTime(11).toSet === Set((1, 100, 10))) + assert(map.getByTime(10).toSet === Set.empty) + assert(map.getByTime(9).toSet === Set.empty) + assert(map.getAll().toSet === Set((1, 100, 10), (2, 200, 20))) + + map.remove(1) + assert(map.get(1) === None) + assert(map.getAll().toSet === Set((2, 200, 20))) + } + + test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove with copy") { + val parentMap = new OpenHashMapBasedStateMap[Int, Int]() + parentMap.put(1, 100, 1) + parentMap.put(2, 200, 2) + parentMap.remove(1) + + // Create child map and make changes + val map = parentMap.copy() + assert(map.get(1) === None) + assert(map.get(2) === Some(200)) + assert(map.getByTime(10).toSet === Set((2, 200, 2))) + assert(map.getByTime(2).toSet === Set.empty) + assert(map.getAll().toSet === Set((2, 200, 2))) + + // Add new items + map.put(3, 300, 3) + assert(map.get(3) === Some(300)) + map.put(4, 400, 4) + assert(map.get(4) === Some(400)) + assert(map.getByTime(10).toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4))) + assert(map.getByTime(4).toSet === Set((2, 200, 2), (3, 300, 3))) + assert(map.getAll().toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + // Remove items + map.remove(4) + assert(map.get(4) === None) // item added in this map, then removed in this map + map.remove(2) + assert(map.get(2) === None) // item removed in parent map, then added in this map + assert(map.getAll().toSet === Set((3, 300, 3))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + // Update items + map.put(1, 1000, 100) + assert(map.get(1) === Some(1000)) // item removed in parent map, then added in this map + map.put(2, 2000, 200) + assert(map.get(2) === Some(2000)) // item added in parent map, then removed + added in this map + map.put(3, 3000, 300) + assert(map.get(3) === Some(3000)) // item added + updated in this map + map.put(4, 4000, 400) + assert(map.get(4) === Some(4000)) // item removed + updated in this map + + assert(map.getAll().toSet === + Set((1, 1000, 100), (2, 2000, 200), (3, 3000, 300), (4, 4000, 400))) + assert(parentMap.getAll().toSet === Set((2, 200, 2))) + + map.remove(2) // remove item present in parent map, so that its not visible in child map + + // Create child map and see availability of items + val childMap = map.copy() + assert(childMap.getAll().toSet === map.getAll().toSet) + assert(childMap.get(1) === Some(1000)) // item removed in grandparent, but added in parent map + assert(childMap.get(2) === None) // item added in grandparent, but removed in parent map + assert(childMap.get(3) === Some(3000)) // item added and updated in parent map + + childMap.put(2, 20000, 200) + assert(childMap.get(2) === Some(20000)) // item map + } + + test("OpenHashMapBasedStateMap - serializing and deserializing") { + val map1 = new OpenHashMapBasedStateMap[Int, Int]() + map1.put(1, 100, 1) + map1.put(2, 200, 2) + + val map2 = map1.copy() + map2.put(3, 300, 3) + map2.put(4, 400, 4) + + val map3 = map2.copy() + map3.put(3, 600, 3) + map3.remove(2) + + // Do not test compaction + assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + + val deser_map3 = Utils.deserialize[StateMap[Int, Int]]( + Utils.serialize(map3), Thread.currentThread().getContextClassLoader) + assertMap(deser_map3, map3, 1, "Deserialized map not same as original map") + } + + test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") { + val targetDeltaLength = 10 + val deltaChainThreshold = 5 + + var map = new OpenHashMapBasedStateMap[Int, Int]( + deltaChainThreshold = deltaChainThreshold) + + // Make large delta chain with length more than deltaChainThreshold + for(i <- 1 to targetDeltaLength) { + map.put(Random.nextInt(), Random.nextInt(), 1) + map = map.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] + } + assert(map.deltaChainLength > deltaChainThreshold) + assert(map.shouldCompact === true) + + val deser_map = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]]( + Utils.serialize(map), Thread.currentThread().getContextClassLoader) + assert(deser_map.deltaChainLength < deltaChainThreshold) + assert(deser_map.shouldCompact === false) + assertMap(deser_map, map, 1, "Deserialized + compacted map not same as original map") + } + + test("OpenHashMapBasedStateMap - all possible sequences of operations with copies ") { + /* + * This tests the map using all permutations of sequences operations, across multiple map + * copies as well as between copies. It is to ensure complete coverage, though it is + * kind of hard to debug this. It is set up as follows. + * + * - For any key, there can be 2 types of update ops on a state map - put or remove + * + * - These operations are done on a test map in "sets". After each set, the map is "copied" + * to create a new map, and the next set of operations are done on the new one. This tests + * whether the map data persistes correctly across copies. + * + * - Within each set, there are a number of operations to test whether the map correctly + * updates and removes data without affecting the parent state map. + * + * - Overall this creates (numSets * numOpsPerSet) operations, each of which that can 2 types + * of operations. This leads to a total of [2 ^ (numSets * numOpsPerSet)] different sequence + * of operations, which we will test with different keys. + * + * Example: With numSets = 2, and numOpsPerSet = 2 give numTotalOps = 4. This means that + * 2 ^ 4 = 16 possible permutations needs to be tested using 16 keys. + * _______________________________________________ + * | | Set1 | Set2 | + * | |-----------------|-----------------| + * | | Op1 Op2 |c| Op3 Op4 | + * |---------|----------------|o|----------------| + * | key 0 | put put |p| put put | + * | key 1 | put put |y| put rem | + * | key 2 | put put | | rem put | + * | key 3 | put put |t| rem rem | + * | key 4 | put rem |h| put put | + * | key 5 | put rem |e| put rem | + * | key 6 | put rem | | rem put | + * | key 7 | put rem |s| rem rem | + * | key 8 | rem put |t| put put | + * | key 9 | rem put |a| put rem | + * | key 10 | rem put |t| rem put | + * | key 11 | rem put |e| rem rem | + * | key 12 | rem rem | | put put | + * | key 13 | rem rem |m| put rem | + * | key 14 | rem rem |a| rem put | + * | key 15 | rem rem |p| rem rem | + * |_________|________________|_|________________| + */ + + val numTypeMapOps = 2 // 0 = put a new value, 1 = remove value + val numSets = 3 + val numOpsPerSet = 3 // to test seq of ops like update -> remove -> update in same set + val numTotalOps = numOpsPerSet * numSets + val numKeys = math.pow(numTypeMapOps, numTotalOps).toInt // to get all combinations of ops + + val refMap = new mutable.HashMap[Int, (Int, Long)]() + var prevSetRefMap: immutable.Map[Int, (Int, Long)] = null + + var stateMap: StateMap[Int, Int] = new OpenHashMapBasedStateMap[Int, Int]() + var prevSetStateMap: StateMap[Int, Int] = null + + var time = 1L + + for (setId <- 0 until numSets) { + for (opInSetId <- 0 until numOpsPerSet) { + val opId = setId * numOpsPerSet + opInSetId + for (keyId <- 0 until numKeys) { + time += 1 + // Find the operation type that needs to be done + // This is similar to finding the nth bit value of a binary number + // E.g. nth bit from the right of any binary number B is [ B / (2 ^ (n - 1)) ] % 2 + val opCode = + (keyId / math.pow(numTypeMapOps, numTotalOps - opId - 1).toInt) % numTypeMapOps + opCode match { + case 0 => + val value = Random.nextInt() + stateMap.put(keyId, value, time) + refMap.put(keyId, (value, time)) + case 1 => + stateMap.remove(keyId) + refMap.remove(keyId) + } + } + + // Test whether the current state map after all key updates is correct + assertMap(stateMap, refMap, time, "State map does not match reference map") + + // Test whether the previous map before copy has not changed + if (prevSetStateMap != null && prevSetRefMap != null) { + assertMap(prevSetStateMap, prevSetRefMap, time, + "Parent state map somehow got modified, does not match corresponding reference map") + } + } + + // Copy the map and remember the previous maps for future tests + prevSetStateMap = stateMap + prevSetRefMap = refMap.toMap + stateMap = stateMap.copy() + + // Assert that the copied map has the same data + assertMap(stateMap, prevSetRefMap, time, + "State map does not match reference map after copying") + } + assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map") + } + + // Assert whether all the data and operations on a state map matches that of a reference state map + private def assertMap( + mapToTest: StateMap[Int, Int], + refMapToTestWith: StateMap[Int, Int], + time: Long, + msg: String): Unit = { + withClue(msg) { + // Assert all the data is same as the reference map + assert(mapToTest.getAll().toSet === refMapToTestWith.getAll().toSet) + + // Assert that get on every key returns the right value + for (keyId <- refMapToTestWith.getAll().map { _._1 }) { + assert(mapToTest.get(keyId) === refMapToTestWith.get(keyId)) + } + + // Assert that every time threshold returns the correct data + for (t <- 0L to (time + 1)) { + assert(mapToTest.getByTime(t).toSet === refMapToTestWith.getByTime(t).toSet) + } + } + } + + // Assert whether all the data and operations on a state map matches that of a reference map + private def assertMap( + mapToTest: StateMap[Int, Int], + refMapToTestWith: Map[Int, (Int, Long)], + time: Long, + msg: String): Unit = { + withClue(msg) { + // Assert all the data is same as the reference map + assert(mapToTest.getAll().toSet === + refMapToTestWith.iterator.map { x => (x._1, x._2._1, x._2._2) }.toSet) + + // Assert that get on every key returns the right value + for (keyId <- refMapToTestWith.keys) { + assert(mapToTest.get(keyId) === refMapToTestWith.get(keyId).map { _._1 }) + } + + // Assert that every time threshold returns the correct data + for (t <- 0L to (time + 1)) { + val expectedRecords = + refMapToTestWith.iterator.filter { _._2._2 < t }.map { x => (x._1, x._2._1, x._2._2) } + assert(mapToTest.getByTime(t).toSet === expectedRecords.toSet) + } + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 0d58a7b54412f..a45c92d9c7bc8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -98,7 +98,7 @@ class TestOutputStream[T: ClassTag]( ) extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() output += collected - }) { + }, false) { // This is to clear the output buffer every it is read from a checkpoint @throws(classOf[IOException]) @@ -122,7 +122,7 @@ class TestOutputStreamWithPartitions[T: ClassTag]( extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.glom().collect().map(_.toSeq) output += collected - }) { + }, false) { // This is to clear the output buffer every it is read from a checkpoint @throws(classOf[IOException]) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala new file mode 100644 index 0000000000000..e3072b4442840 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -0,0 +1,494 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.File + +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.reflect.ClassTag + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.streaming.dstream.{TrackStateDStream, TrackStateDStreamImpl} +import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} + +class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { + + private var sc: SparkContext = null + private var ssc: StreamingContext = null + private var checkpointDir: File = null + private val batchDuration = Seconds(1) + + before { + StreamingContext.getActive().foreach { + _.stop(stopSparkContext = false) + } + checkpointDir = Utils.createTempDir("checkpoint") + + ssc = new StreamingContext(sc, batchDuration) + ssc.checkpoint(checkpointDir.toString) + } + + after { + StreamingContext.getActive().foreach { + _.stop(stopSparkContext = false) + } + } + + override def beforeAll(): Unit = { + val conf = new SparkConf().setMaster("local").setAppName("TrackStateByKeySuite") + conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + sc = new SparkContext(conf) + } + + test("state - get, exists, update, remove, ") { + var state: StateImpl[Int] = null + + def testState( + expectedData: Option[Int], + shouldBeUpdated: Boolean = false, + shouldBeRemoved: Boolean = false, + shouldBeTimingOut: Boolean = false + ): Unit = { + if (expectedData.isDefined) { + assert(state.exists) + assert(state.get() === expectedData.get) + assert(state.getOption() === expectedData) + assert(state.getOption.getOrElse(-1) === expectedData.get) + } else { + assert(!state.exists) + intercept[NoSuchElementException] { + state.get() + } + assert(state.getOption() === None) + assert(state.getOption.getOrElse(-1) === -1) + } + + assert(state.isTimingOut() === shouldBeTimingOut) + if (shouldBeTimingOut) { + intercept[IllegalArgumentException] { + state.remove() + } + intercept[IllegalArgumentException] { + state.update(-1) + } + } + + assert(state.isUpdated() === shouldBeUpdated) + + assert(state.isRemoved() === shouldBeRemoved) + if (shouldBeRemoved) { + intercept[IllegalArgumentException] { + state.remove() + } + intercept[IllegalArgumentException] { + state.update(-1) + } + } + } + + state = new StateImpl[Int]() + testState(None) + + state.wrap(None) + testState(None) + + state.wrap(Some(1)) + testState(Some(1)) + + state.update(2) + testState(Some(2), shouldBeUpdated = true) + + state = new StateImpl[Int]() + state.update(2) + testState(Some(2), shouldBeUpdated = true) + + state.remove() + testState(None, shouldBeRemoved = true) + + state.wrapTiminoutState(3) + testState(Some(3), shouldBeTimingOut = true) + } + + test("trackStateByKey - basic operations with simple API") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq(1), + Seq(2, 1), + Seq(3, 2, 1), + Seq(4, 3), + Seq(5), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + // state maintains running count, and updated count is returned + val trackStateFunc = (value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + state.update(sum) + sum + } + + testOperation[String, Int, Int]( + inputData, StateSpec.function(trackStateFunc), outputData, stateData) + } + + test("trackStateByKey - basic operations with advanced API") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq("aa"), + Seq("aa", "bb"), + Seq("aa", "bb", "cc"), + Seq("aa", "bb"), + Seq("aa"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + // state maintains running count, key string doubled and returned + val trackStateFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + state.update(sum) + Some(key * 2) + } + + testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData) + } + + test("trackStateByKey - type inferencing and class tags") { + + // Simple track state function with value as Int, state as Double and emitted type as Double + val simpleFunc = (value: Option[Int], state: State[Double]) => { + 0L + } + + // Advanced track state function with key as String, value as Int, state as Double and + // emitted type as Double + val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => { + Some(0L) + } + + def testTypes(dstream: TrackStateDStream[_, _, _, _]): Unit = { + val dstreamImpl = dstream.asInstanceOf[TrackStateDStreamImpl[_, _, _, _]] + assert(dstreamImpl.keyClass === classOf[String]) + assert(dstreamImpl.valueClass === classOf[Int]) + assert(dstreamImpl.stateClass === classOf[Double]) + assert(dstreamImpl.emittedClass === classOf[Long]) + } + + val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2) + + // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types + val simpleFunctionStateStream1 = inputStream.trackStateByKey( + StateSpec.function(simpleFunc).numPartitions(1)) + testTypes(simpleFunctionStateStream1) + + // Separately defining StateSpec with simple function requires explicitly specifying types + val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc) + val simpleFunctionStateStream2 = inputStream.trackStateByKey(simpleFuncSpec) + testTypes(simpleFunctionStateStream2) + + // Separately defining StateSpec with advanced function implicitly gets the types + val advFuncSpec1 = StateSpec.function(advancedFunc) + val advFunctionStateStream1 = inputStream.trackStateByKey(advFuncSpec1) + testTypes(advFunctionStateStream1) + + // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types + val advFunctionStateStream2 = inputStream.trackStateByKey( + StateSpec.function(simpleFunc).numPartitions(1)) + testTypes(advFunctionStateStream2) + + // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types + val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc) + val advFunctionStateStream3 = inputStream.trackStateByKey[Double, Long](advFuncSpec2) + testTypes(advFunctionStateStream3) + } + + test("trackStateByKey - states as emitted records") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3)), + Seq(("a", 5)), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + val output = (key, sum) + state.update(sum) + Some(output) + } + + testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData) + } + + test("trackStateByKey - initial states, with nothing emitted") { + + val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)) + + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val outputData = Seq.fill(inputData.size)(Seq.empty[Int]) + + val stateData = + Seq( + Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)), + Seq(("a", 6), ("b", 10), ("c", -20), ("d", 0)), + Seq(("a", 7), ("b", 11), ("c", -20), ("d", 0)), + Seq(("a", 8), ("b", 12), ("c", -19), ("d", 0)), + Seq(("a", 9), ("b", 13), ("c", -19), ("d", 0)), + Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)), + Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)) + ) + + val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + val sum = value.getOrElse(0) + state.getOption.getOrElse(0) + val output = (key, sum) + state.update(sum) + None.asInstanceOf[Option[Int]] + } + + val trackStateSpec = StateSpec.function(trackStateFunc).initialState(sc.makeRDD(initialState)) + testOperation(inputData, trackStateSpec, outputData, stateData) + } + + test("trackStateByKey - state removing") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), // a will be removed + Seq("a", "b", "c"), // b will be removed + Seq("a", "b", "c"), // a and c will be removed + Seq("a", "b"), // b will be removed + Seq("a"), // a will be removed + Seq() + ) + + // States that were removed + val outputData = + Seq( + Seq(), + Seq(), + Seq("a"), + Seq("b"), + Seq("a", "c"), + Seq("b"), + Seq("a"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("b", 1)), + Seq(("a", 1), ("c", 1)), + Seq(("b", 1)), + Seq(("a", 1)), + Seq(), + Seq() + ) + + val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + if (state.exists) { + state.remove() + Some(key) + } else { + state.update(value.get) + None + } + } + + testOperation( + inputData, StateSpec.function(trackStateFunc).numPartitions(1), outputData, stateData) + } + + test("trackStateByKey - state timing out") { + val inputData = + Seq( + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq(), // c will time out + Seq(), // b will time out + Seq("a") // a will not time out + ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active + + val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => { + if (value.isDefined) { + state.update(1) + } + if (state.isTimingOut) { + Some(key) + } else { + None + } + } + + val (collectedOutputs, collectedStateSnapshots) = getOperationOutput( + inputData, StateSpec.function(trackStateFunc).timeout(Seconds(3)), 20) + + // b and c should be emitted once each, when they were marked as expired + assert(collectedOutputs.flatten.sorted === Seq("b", "c")) + + // States for a, b, c should be defined at one point of time + assert(collectedStateSnapshots.exists { + _.toSet == Set(("a", 1), ("b", 1), ("c", 1)) + }) + + // Finally state should be defined only for a + assert(collectedStateSnapshots.last.toSet === Set(("a", 1))) + } + + + private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( + input: Seq[Seq[K]], + trackStateSpec: StateSpec[K, Int, S, T], + expectedOutputs: Seq[Seq[T]], + expectedStateSnapshots: Seq[Seq[(K, S)]] + ): Unit = { + require(expectedOutputs.size == expectedStateSnapshots.size) + + val (collectedOutputs, collectedStateSnapshots) = + getOperationOutput(input, trackStateSpec, expectedOutputs.size) + assert(expectedOutputs, collectedOutputs, "outputs") + assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots") + } + + private def getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag]( + input: Seq[Seq[K]], + trackStateSpec: StateSpec[K, Int, S, T], + numBatches: Int + ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = { + + // Setup the stream computation + val inputStream = new TestInputStream(ssc, input, numPartitions = 2) + val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec) + val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] + val outputStream = new TestOutputStream(trackeStateStream, collectedOutputs) + val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]] + val stateSnapshotStream = new TestOutputStream( + trackeStateStream.stateSnapshots(), collectedStateSnapshots) + outputStream.register() + stateSnapshotStream.register() + + val batchCounter = new BatchCounter(ssc) + ssc.start() + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds * numBatches) + + batchCounter.waitUntilBatchesCompleted(numBatches, 10000) + (collectedOutputs, collectedStateSnapshots) + } + + private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String) { + val debugString = "\nExpected:\n" + expected.mkString("\n") + + "\nCollected:\n" + collected.mkString("\n") + assert(expected.size === collected.size, + s"number of collected $typ (${collected.size}) different from expected (${expected.size})" + + debugString) + expected.zip(collected).foreach { case (c, e) => + assert(c.toSet === e.toSet, + s"collected $typ is different from expected $debugString" + ) + } + } +} + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala new file mode 100644 index 0000000000000..19ef5a14f8ab4 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.rdd + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.util.OpenHashMapBasedStateMap +import org.apache.spark.streaming.{Time, State} +import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite} + +class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { + + private var sc: SparkContext = null + + override def beforeAll(): Unit = { + sc = new SparkContext( + new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite")) + } + + override def afterAll(): Unit = { + if (sc != null) { + sc.stop() + } + } + + test("creation from pair RDD") { + val data = Seq((1, "1"), (2, "2"), (3, "3")) + val partitioner = new HashPartitioner(10) + val rdd = TrackStateRDD.createFromPairRDD[Int, Int, String, Int]( + sc.parallelize(data), partitioner, Time(123)) + assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123)}.toSet, Set.empty) + assert(rdd.partitions.size === partitioner.numPartitions) + + assert(rdd.partitioner === Some(partitioner)) + } + + test("updating state and generating emitted data in TrackStateRecord") { + + val initialTime = 1000L + val updatedTime = 2000L + val thresholdTime = 1500L + @volatile var functionCalled = false + + /** + * Assert that applying given data on a prior record generates correct updated record, with + * correct state map and emitted data + */ + def assertRecordUpdate( + initStates: Iterable[Int], + data: Iterable[String], + expectedStates: Iterable[(Int, Long)], + timeoutThreshold: Option[Long] = None, + removeTimedoutData: Boolean = false, + expectedOutput: Iterable[Int] = None, + expectedTimingOutStates: Iterable[Int] = None, + expectedRemovedStates: Iterable[Int] = None + ): Unit = { + val initialStateMap = new OpenHashMapBasedStateMap[String, Int]() + initStates.foreach { s => initialStateMap.put("key", s, initialTime) } + functionCalled = false + val record = TrackStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty) + val dataIterator = data.map { v => ("key", v) }.iterator + val removedStates = new ArrayBuffer[Int] + val timingOutStates = new ArrayBuffer[Int] + /** + * Tracking function that updates/removes state based on instructions in the data, and + * return state (when instructed or when state is timing out). + */ + def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = { + functionCalled = true + + assert(t.milliseconds === updatedTime, "tracking func called with wrong time") + + data match { + case Some("noop") => + None + case Some("get-state") => + Some(state.getOption().getOrElse(-1)) + case Some("update-state") => + if (state.exists) state.update(state.get + 1) else state.update(0) + None + case Some("remove-state") => + removedStates += state.get() + state.remove() + None + case None => + assert(state.isTimingOut() === true, "State is not timing out when data = None") + timingOutStates += state.get() + None + case _ => + fail("Unexpected test data") + } + } + + val updatedRecord = TrackStateRDDRecord.updateRecordWithData[String, String, Int, Int]( + Some(record), dataIterator, testFunc, + Time(updatedTime), timeoutThreshold, removeTimedoutData) + + val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) } + assert(updatedStateData.toSet === expectedStates.toSet, + "states do not match after updating the TrackStateRecord") + + assert(updatedRecord.emittedRecords.toSet === expectedOutput.toSet, + "emitted data do not match after updating the TrackStateRecord") + + assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " + + "match those that were expected to do so while updating the TrackStateRecord") + + assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " + + "match those that were expected to do so while updating the TrackStateRecord") + + } + + // No data, no state should be changed, function should not be called, + assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil) + assert(functionCalled === false) + assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime))) + assert(functionCalled === false) + + // Data present, function should be called irrespective of whether state exists + assertRecordUpdate(initStates = Seq(0), data = Seq("noop"), + expectedStates = Seq((0, initialTime))) + assert(functionCalled === true) + assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None) + assert(functionCalled === true) + + // Function called with right state data + assertRecordUpdate(initStates = None, data = Seq("get-state"), + expectedStates = None, expectedOutput = Seq(-1)) + assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"), + expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123)) + + // Update state and timestamp, when timeout not present + assertRecordUpdate(initStates = Nil, data = Seq("update-state"), + expectedStates = Seq((0, updatedTime))) + assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"), + expectedStates = Seq((1, updatedTime))) + + // Remove state + assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"), + expectedStates = Nil, expectedRemovedStates = Seq(345)) + + // State strictly older than timeout threshold should be timed out + assertRecordUpdate(initStates = Seq(123), data = Nil, + timeoutThreshold = Some(initialTime), removeTimedoutData = true, + expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil) + + assertRecordUpdate(initStates = Seq(123), data = Nil, + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Nil, expectedTimingOutStates = Seq(123)) + + // State should not be timed out after it has received data + assertRecordUpdate(initStates = Seq(123), data = Seq("noop"), + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil) + assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"), + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123)) + + } + + test("states generated by TrackStateRDD") { + val initStates = Seq(("k1", 0), ("k2", 0)) + val initTime = 123 + val initStateWthTime = initStates.map { x => (x._1, x._2, initTime) }.toSet + val partitioner = new HashPartitioner(2) + val initStateRDD = TrackStateRDD.createFromPairRDD[String, Int, Int, Int]( + sc.parallelize(initStates), partitioner, Time(initTime)).persist() + assertRDD(initStateRDD, initStateWthTime, Set.empty) + + val updateTime = 345 + + /** + * Test that the test state RDD, when operated with new data, + * creates a new state RDD with expected states + */ + def testStateUpdates( + testStateRDD: TrackStateRDD[String, Int, Int, Int], + testData: Seq[(String, Int)], + expectedStates: Set[(String, Int, Int)]): TrackStateRDD[String, Int, Int, Int] = { + + // Persist the test TrackStateRDD so that its not recomputed while doing the next operation. + // This is to make sure that we only track which state keys are being touched in the next op. + testStateRDD.persist().count() + + // To track which keys are being touched + TrackStateRDDSuite.touchedStateKeys.clear() + + val trackingFunc = (time: Time, key: String, data: Option[Int], state: State[Int]) => { + + // Track the key that has been touched + TrackStateRDDSuite.touchedStateKeys += key + + // If the data is 0, do not do anything with the state + // else if the data is 1, increment the state if it exists, or set new state to 0 + // else if the data is 2, remove the state if it exists + data match { + case Some(1) => + if (state.exists()) { state.update(state.get + 1) } + else state.update(0) + case Some(2) => + state.remove() + case _ => + } + None.asInstanceOf[Option[Int]] // Do not return anything, not being tested + } + val newDataRDD = sc.makeRDD(testData).partitionBy(testStateRDD.partitioner.get) + + // Assert that the new state RDD has expected state data + val newStateRDD = assertOperation( + testStateRDD, newDataRDD, trackingFunc, updateTime, expectedStates, Set.empty) + + // Assert that the function was called only for the keys present in the data + assert(TrackStateRDDSuite.touchedStateKeys.size === testData.size, + "More number of keys are being touched than that is expected") + assert(TrackStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys, + "Keys not in the data are being touched unexpectedly") + + // Assert that the test RDD's data has not changed + assertRDD(initStateRDD, initStateWthTime, Set.empty) + newStateRDD + } + + // Test no-op, no state should change + testStateUpdates(initStateRDD, Seq(), initStateWthTime) // should not scan any state + testStateUpdates( + initStateRDD, Seq(("k1", 0)), initStateWthTime) // should not update existing state + testStateUpdates( + initStateRDD, Seq(("k3", 0)), initStateWthTime) // should not create new state + + // Test creation of new state + val rdd1 = testStateUpdates(initStateRDD, Seq(("k3", 1)), // should create k3's state as 0 + Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime))) + + val rdd2 = testStateUpdates(rdd1, Seq(("k4", 1)), // should create k4's state as 0 + Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime), ("k4", 0, updateTime))) + + // Test updating of state + val rdd3 = testStateUpdates( + initStateRDD, Seq(("k1", 1)), // should increment k1's state 0 -> 1 + Set(("k1", 1, updateTime), ("k2", 0, initTime))) + + val rdd4 = testStateUpdates(rdd3, + Seq(("x", 0), ("k2", 1), ("k2", 1), ("k3", 1)), // should update k2, 0 -> 2 and create k3, 0 + Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 0, updateTime))) + + val rdd5 = testStateUpdates( + rdd4, Seq(("k3", 1)), // should update k3's state 0 -> 2 + Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 1, updateTime))) + + // Test removing of state + val rdd6 = testStateUpdates( // should remove k1's state + initStateRDD, Seq(("k1", 2)), Set(("k2", 0, initTime))) + + val rdd7 = testStateUpdates( // should remove k2's state + rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime))) + + val rdd8 = testStateUpdates( // should remove k3's state + rdd7, Seq(("k3", 2)), Set()) + } + + /** Assert whether the `trackStateByKey` operation generates expected results */ + private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + testStateRDD: TrackStateRDD[K, V, S, T], + newDataRDD: RDD[(K, V)], + trackStateFunc: (Time, K, Option[V], State[S]) => Option[T], + currentTime: Long, + expectedStates: Set[(K, S, Int)], + expectedEmittedRecords: Set[T], + doFullScan: Boolean = false + ): TrackStateRDD[K, V, S, T] = { + + val partitionedNewDataRDD = if (newDataRDD.partitioner != testStateRDD.partitioner) { + newDataRDD.partitionBy(testStateRDD.partitioner.get) + } else { + newDataRDD + } + + val newStateRDD = new TrackStateRDD[K, V, S, T]( + testStateRDD, newDataRDD, trackStateFunc, Time(currentTime), None) + if (doFullScan) newStateRDD.setFullScan() + + // Persist to make sure that it gets computed only once and we can track precisely how many + // state keys the computing touched + newStateRDD.persist().count() + assertRDD(newStateRDD, expectedStates, expectedEmittedRecords) + newStateRDD + } + + /** Assert whether the [[TrackStateRDD]] has the expected state ad emitted records */ + private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( + trackStateRDD: TrackStateRDD[K, V, S, T], + expectedStates: Set[(K, S, Int)], + expectedEmittedRecords: Set[T]): Unit = { + val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet + val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet + assert(states === expectedStates, + "states after track state operation were not as expected") + assert(emittedRecords === expectedEmittedRecords, + "emitted records after track state operation were not as expected") + } +} + +object TrackStateRDDSuite { + private val touchedStateKeys = new ArrayBuffer[String]() +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala index 2f11b255f1104..92ad9fe52b777 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming.receiver import scala.collection.mutable +import scala.language.reflectiveCalls import org.scalatest.BeforeAndAfter import org.scalatest.Matchers._ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index af4718b4eb705..34cd7435569e1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -130,20 +130,20 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (600) // onReceiverStarted - val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost") + val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost", "0") listener.onReceiverStarted(StreamingListenerReceiverStarted(receiverInfoStarted)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (None) // onReceiverError - val receiverInfoError = ReceiverInfo(1, "test", true, "localhost") + val receiverInfoError = ReceiverInfo(1, "test", true, "localhost", "1") listener.onReceiverError(StreamingListenerReceiverError(receiverInfoError)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) listener.receiverInfo(2) should be (None) // onReceiverStopped - val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost") + val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost", "2") listener.onReceiverStopped(StreamingListenerReceiverStopped(receiverInfoStopped)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 93ae41a3d2ecd..7f80d6ecdbbb5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -18,31 +18,45 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer -import java.util +import java.util.{Iterator => JIterator} +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{RejectedExecutionException, TimeUnit, CountDownLatch, ThreadPoolExecutor} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import scala.concurrent._ import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} -import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.mockito.Matchers.{eq => meq} +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ -import org.scalatest.BeforeAndAfter +import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach, BeforeAndAfter} +import org.scalatest.mock.MockitoSugar -import org.apache.spark.util.{ManualClock, Utils} -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.util.{CompletionIterator, ThreadUtils, ManualClock, Utils} +import org.apache.spark.{SparkConf, SparkFunSuite} -class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { +/** Common tests for WriteAheadLogs that we would like to test with different configurations. */ +abstract class CommonWriteAheadLogTests( + allowBatching: Boolean, + closeFileAfterWrite: Boolean, + testTag: String = "") + extends SparkFunSuite with BeforeAndAfter { import WriteAheadLogSuite._ - val hadoopConf = new Configuration() - var tempDir: File = null - var testDir: String = null - var testFile: String = null - var writeAheadLog: FileBasedWriteAheadLog = null + protected val hadoopConf = new Configuration() + protected var tempDir: File = null + protected var testDir: String = null + protected var testFile: String = null + protected var writeAheadLog: WriteAheadLog = null + protected def testPrefix = if (testTag != "") testTag + " - " else testTag before { tempDir = Utils.createTempDir() @@ -58,47 +72,208 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { Utils.deleteRecursively(tempDir) } - test("WriteAheadLogUtils - log selection and creation") { - val logDir = Utils.createTempDir().getAbsolutePath() + test(testPrefix + "read all logs") { + // Write data manually for testing reading through WriteAheadLog + val writtenData = (1 to 10).map { i => + val data = generateRandomData() + val file = testDir + s"/log-$i-$i" + writeDataManually(data, file, allowBatching) + data + }.flatten + + val logDirectoryPath = new Path(testDir) + val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) + assert(fileSystem.exists(logDirectoryPath) === true) + + // Read data using manager and verify + val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(readData === writtenData) + } + + test(testPrefix + "write logs") { + // Write data with rotation using WriteAheadLog class + val dataToWrite = generateRandomData() + writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite = closeFileAfterWrite, + allowBatching = allowBatching) + + // Read data manually to verify the written data + val logFiles = getLogFilesInDirectory(testDir) + assert(logFiles.size > 1) + val writtenData = readAndDeserializeDataManually(logFiles, allowBatching) + assert(writtenData === dataToWrite) + } + + test(testPrefix + "read all logs after write") { + // Write data with manager, recover with new manager and verify + val dataToWrite = generateRandomData() + writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite, allowBatching) + val logFiles = getLogFilesInDirectory(testDir) + assert(logFiles.size > 1) + val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(dataToWrite === readData) + } - def assertDriverLogClass[T <: WriteAheadLog: ClassTag](conf: SparkConf): WriteAheadLog = { - val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf) - assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) - log + test(testPrefix + "clean old logs") { + logCleanUpTest(waitForCompletion = false) + } + + test(testPrefix + "clean old logs synchronously") { + logCleanUpTest(waitForCompletion = true) + } + + private def logCleanUpTest(waitForCompletion: Boolean): Unit = { + // Write data with manager, recover with new manager and verify + val manualClock = new ManualClock + val dataToWrite = generateRandomData() + writeAheadLog = writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite, + allowBatching, manualClock, closeLog = false) + val logFiles = getLogFilesInDirectory(testDir) + assert(logFiles.size > 1) + + writeAheadLog.clean(manualClock.getTimeMillis() / 2, waitForCompletion) + + if (waitForCompletion) { + assert(getLogFilesInDirectory(testDir).size < logFiles.size) + } else { + eventually(Eventually.timeout(1 second), interval(10 milliseconds)) { + assert(getLogFilesInDirectory(testDir).size < logFiles.size) + } } + } - def assertReceiverLogClass[T: ClassTag](conf: SparkConf): WriteAheadLog = { - val log = WriteAheadLogUtils.createLogForReceiver(conf, logDir, hadoopConf) - assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) - log + test(testPrefix + "handling file errors while reading rotating logs") { + // Generate a set of log files + val manualClock = new ManualClock + val dataToWrite1 = generateRandomData() + writeDataUsingWriteAheadLog(testDir, dataToWrite1, closeFileAfterWrite, allowBatching, + manualClock) + val logFiles1 = getLogFilesInDirectory(testDir) + assert(logFiles1.size > 1) + + + // Recover old files and generate a second set of log files + val dataToWrite2 = generateRandomData() + manualClock.advance(100000) + writeDataUsingWriteAheadLog(testDir, dataToWrite2, closeFileAfterWrite, allowBatching , + manualClock) + val logFiles2 = getLogFilesInDirectory(testDir) + assert(logFiles2.size > logFiles1.size) + + // Read the files and verify that all the written data can be read + val readData1 = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(readData1 === (dataToWrite1 ++ dataToWrite2)) + + // Corrupt the first set of files so that they are basically unreadable + logFiles1.foreach { f => + val raf = new FileOutputStream(f, true).getChannel() + raf.truncate(1) + raf.close() } - val emptyConf = new SparkConf() // no log configuration - assertDriverLogClass[FileBasedWriteAheadLog](emptyConf) - assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf) - - // Verify setting driver WAL class - val conf1 = new SparkConf().set("spark.streaming.driver.writeAheadLog.class", - classOf[MockWriteAheadLog0].getName()) - assertDriverLogClass[MockWriteAheadLog0](conf1) - assertReceiverLogClass[FileBasedWriteAheadLog](conf1) - - // Verify setting receiver WAL class - val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", - classOf[MockWriteAheadLog0].getName()) - assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf) - assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) - - // Verify setting receiver WAL class with 1-arg constructor - val receiverWALConf2 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", - classOf[MockWriteAheadLog1].getName()) - assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf2) - - // Verify failure setting receiver WAL class with 2-arg constructor - intercept[SparkException] { - val receiverWALConf3 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", - classOf[MockWriteAheadLog2].getName()) - assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf3) + // Verify that the corrupted files do not prevent reading of the second set of data + val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(readData === dataToWrite2) + } + + test(testPrefix + "do not create directories or files unless write") { + val nonexistentTempPath = File.createTempFile("test", "") + nonexistentTempPath.delete() + assert(!nonexistentTempPath.exists()) + + val writtenSegment = writeDataManually(generateRandomData(), testFile, allowBatching) + val wal = createWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(!nonexistentTempPath.exists(), "Directory created just by creating log object") + if (allowBatching) { + intercept[UnsupportedOperationException](wal.read(writtenSegment.head)) + } else { + wal.read(writtenSegment.head) + } + assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment") + } + + test(testPrefix + "parallel recovery not enabled if closeFileAfterWrite = false") { + // write some data + val writtenData = (1 to 10).map { i => + val data = generateRandomData() + val file = testDir + s"/log-$i-$i" + writeDataManually(data, file, allowBatching) + data + }.flatten + + val wal = createWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + // create iterator but don't materialize it + val readData = wal.readAll().asScala.map(byteBufferToString) + wal.close() + if (closeFileAfterWrite) { + // the threadpool is shutdown by the wal.close call above, therefore we shouldn't be able + // to materialize the iterator with parallel recovery + intercept[RejectedExecutionException](readData.toArray) + } else { + assert(readData.toSeq === writtenData) + } + } +} + +class FileBasedWriteAheadLogSuite + extends CommonWriteAheadLogTests(false, false, "FileBasedWriteAheadLog") { + + import WriteAheadLogSuite._ + + test("FileBasedWriteAheadLog - seqToParIterator") { + /* + If the setting `closeFileAfterWrite` is enabled, we start generating a very large number of + files. This causes recovery to take a very long time. In order to make it quicker, we + parallelized the reading of these files. This test makes sure that we limit the number of + open files to the size of the number of threads in our thread pool rather than the size of + the list of files. + */ + val numThreads = 8 + val tpool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "wal-test-thread-pool") + class GetMaxCounter { + private val value = new AtomicInteger() + @volatile private var max: Int = 0 + def increment(): Unit = synchronized { + val atInstant = value.incrementAndGet() + if (atInstant > max) max = atInstant + } + def decrement(): Unit = synchronized { value.decrementAndGet() } + def get(): Int = synchronized { value.get() } + def getMax(): Int = synchronized { max } + } + try { + // If Jenkins is slow, we may not have a chance to run many threads simultaneously. Having + // a latch will make sure that all the threads can be launched altogether. + val latch = new CountDownLatch(1) + val testSeq = 1 to 1000 + val counter = new GetMaxCounter() + def handle(value: Int): Iterator[Int] = { + new CompletionIterator[Int, Iterator[Int]](Iterator(value)) { + counter.increment() + // block so that other threads also launch + latch.await(10, TimeUnit.SECONDS) + override def completion() { counter.decrement() } + } + } + @volatile var collected: Seq[Int] = Nil + val t = new Thread() { + override def run() { + // run the calculation on a separate thread so that we can release the latch + val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](tpool, testSeq, handle) + collected = iterator.toSeq + } + } + t.start() + eventually(Eventually.timeout(10.seconds)) { + // make sure we are doing a parallel computation! + assert(counter.getMax() > 1) + } + latch.countDown() + t.join(10000) + assert(collected === testSeq) + // make sure we didn't open too many Iterators + assert(counter.getMax() <= numThreads) + } finally { + tpool.shutdownNow() } } @@ -122,7 +297,7 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { test("FileBasedWriteAheadLogReader - sequentially reading data") { val writtenData = generateRandomData() - writeDataManually(writtenData, testFile) + writeDataManually(writtenData, testFile, allowBatching = false) val reader = new FileBasedWriteAheadLogReader(testFile, hadoopConf) val readData = reader.toSeq.map(byteBufferToString) assert(readData === writtenData) @@ -163,10 +338,30 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { assert(readDataUsingReader(testFile) === (dataToWrite.dropRight(1))) } + test("FileBasedWriteAheadLogReader - handles errors when file doesn't exist") { + // Write data manually for testing the sequential reader + val dataToWrite = generateRandomData() + writeDataUsingWriter(testFile, dataToWrite) + val tFile = new File(testFile) + assert(tFile.exists()) + // Verify the data can be read and is same as the one correctly written + assert(readDataUsingReader(testFile) === dataToWrite) + + tFile.delete() + assert(!tFile.exists()) + + val reader = new FileBasedWriteAheadLogReader(testFile, hadoopConf) + assert(!reader.hasNext) + reader.close() + + // Verify that no exception is thrown if file doesn't exist + assert(readDataUsingReader(testFile) === Nil) + } + test("FileBasedWriteAheadLogRandomReader - reading data using random reader") { // Write data manually for testing the random reader val writtenData = generateRandomData() - val segments = writeDataManually(writtenData, testFile) + val segments = writeDataManually(writtenData, testFile, allowBatching = false) // Get a random order of these segments and read them back val writtenDataAndSegments = writtenData.zip(segments).toSeq.permutations.take(10).flatten @@ -190,163 +385,210 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { } reader.close() } +} - test("FileBasedWriteAheadLog - write rotating logs") { - // Write data with rotation using WriteAheadLog class - val dataToWrite = generateRandomData() - writeDataUsingWriteAheadLog(testDir, dataToWrite) +abstract class CloseFileAfterWriteTests(allowBatching: Boolean, testTag: String) + extends CommonWriteAheadLogTests(allowBatching, closeFileAfterWrite = true, testTag) { - // Read data manually to verify the written data - val logFiles = getLogFilesInDirectory(testDir) - assert(logFiles.size > 1) - val writtenData = logFiles.flatMap { file => readDataManually(file)} - assert(writtenData === dataToWrite) - } - - test("FileBasedWriteAheadLog - close after write flag") { + import WriteAheadLogSuite._ + test(testPrefix + "close after write flag") { // Write data with rotation using WriteAheadLog class val numFiles = 3 val dataToWrite = Seq.tabulate(numFiles)(_.toString) // total advance time is less than 1000, therefore log shouldn't be rolled, but manually closed writeDataUsingWriteAheadLog(testDir, dataToWrite, closeLog = false, clockAdvanceTime = 100, - closeFileAfterWrite = true) + closeFileAfterWrite = true, allowBatching = allowBatching) // Read data manually to verify the written data val logFiles = getLogFilesInDirectory(testDir) assert(logFiles.size === numFiles) - val writtenData = logFiles.flatMap { file => readDataManually(file)} + val writtenData: Seq[String] = readAndDeserializeDataManually(logFiles, allowBatching) assert(writtenData === dataToWrite) } +} - test("FileBasedWriteAheadLog - read rotating logs") { - // Write data manually for testing reading through WriteAheadLog - val writtenData = (1 to 10).map { i => - val data = generateRandomData() - val file = testDir + s"/log-$i-$i" - writeDataManually(data, file) - data - }.flatten +class FileBasedWriteAheadLogWithFileCloseAfterWriteSuite + extends CloseFileAfterWriteTests(allowBatching = false, "FileBasedWriteAheadLog") - val logDirectoryPath = new Path(testDir) - val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) - assert(fileSystem.exists(logDirectoryPath) === true) +class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( + allowBatching = true, + closeFileAfterWrite = false, + "BatchedWriteAheadLog") + with MockitoSugar + with BeforeAndAfterEach + with Eventually + with PrivateMethodTester { - // Read data using manager and verify - val readData = readDataUsingWriteAheadLog(testDir) - assert(readData === writtenData) - } + import BatchedWriteAheadLog._ + import WriteAheadLogSuite._ - test("FileBasedWriteAheadLog - recover past logs when creating new manager") { - // Write data with manager, recover with new manager and verify - val dataToWrite = generateRandomData() - writeDataUsingWriteAheadLog(testDir, dataToWrite) - val logFiles = getLogFilesInDirectory(testDir) - assert(logFiles.size > 1) - val readData = readDataUsingWriteAheadLog(testDir) - assert(dataToWrite === readData) - } + private var wal: WriteAheadLog = _ + private var walHandle: WriteAheadLogRecordHandle = _ + private var walBatchingThreadPool: ThreadPoolExecutor = _ + private var walBatchingExecutionContext: ExecutionContextExecutorService = _ + private val sparkConf = new SparkConf() - test("FileBasedWriteAheadLog - clean old logs") { - logCleanUpTest(waitForCompletion = false) - } + private val queueLength = PrivateMethod[Int]('getQueueLength) - test("FileBasedWriteAheadLog - clean old logs synchronously") { - logCleanUpTest(waitForCompletion = true) + override def beforeEach(): Unit = { + wal = mock[WriteAheadLog] + walHandle = mock[WriteAheadLogRecordHandle] + walBatchingThreadPool = ThreadUtils.newDaemonFixedThreadPool(8, "wal-test-thread-pool") + walBatchingExecutionContext = ExecutionContext.fromExecutorService(walBatchingThreadPool) } - private def logCleanUpTest(waitForCompletion: Boolean): Unit = { - // Write data with manager, recover with new manager and verify - val manualClock = new ManualClock - val dataToWrite = generateRandomData() - writeAheadLog = writeDataUsingWriteAheadLog(testDir, dataToWrite, manualClock, closeLog = false) - val logFiles = getLogFilesInDirectory(testDir) - assert(logFiles.size > 1) - - writeAheadLog.clean(manualClock.getTimeMillis() / 2, waitForCompletion) - - if (waitForCompletion) { - assert(getLogFilesInDirectory(testDir).size < logFiles.size) - } else { - eventually(timeout(1 second), interval(10 milliseconds)) { - assert(getLogFilesInDirectory(testDir).size < logFiles.size) - } + override def afterEach(): Unit = { + if (walBatchingExecutionContext != null) { + walBatchingExecutionContext.shutdownNow() } } - test("FileBasedWriteAheadLog - handling file errors while reading rotating logs") { - // Generate a set of log files - val manualClock = new ManualClock - val dataToWrite1 = generateRandomData() - writeDataUsingWriteAheadLog(testDir, dataToWrite1, manualClock) - val logFiles1 = getLogFilesInDirectory(testDir) - assert(logFiles1.size > 1) + test("BatchedWriteAheadLog - serializing and deserializing batched records") { + val events = Seq( + BlockAdditionEvent(ReceivedBlockInfo(0, None, None, null)), + BatchAllocationEvent(null, null), + BatchCleanupEvent(Nil) + ) + val buffers = events.map(e => Record(ByteBuffer.wrap(Utils.serialize(e)), 0L, null)) + val batched = BatchedWriteAheadLog.aggregate(buffers) + val deaggregate = BatchedWriteAheadLog.deaggregate(batched).map(buffer => + Utils.deserialize[ReceivedBlockTrackerLogEvent](buffer.array())) - // Recover old files and generate a second set of log files - val dataToWrite2 = generateRandomData() - manualClock.advance(100000) - writeDataUsingWriteAheadLog(testDir, dataToWrite2, manualClock) - val logFiles2 = getLogFilesInDirectory(testDir) - assert(logFiles2.size > logFiles1.size) + assert(deaggregate.toSeq === events) + } - // Read the files and verify that all the written data can be read - val readData1 = readDataUsingWriteAheadLog(testDir) - assert(readData1 === (dataToWrite1 ++ dataToWrite2)) + test("BatchedWriteAheadLog - failures in wrappedLog get bubbled up") { + when(wal.write(any[ByteBuffer], anyLong)).thenThrow(new RuntimeException("Hello!")) + // the BatchedWriteAheadLog should bubble up any exceptions that may have happened during writes + val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) - // Corrupt the first set of files so that they are basically unreadable - logFiles1.foreach { f => - val raf = new FileOutputStream(f, true).getChannel() - raf.truncate(1) - raf.close() + intercept[RuntimeException] { + val buffer = mock[ByteBuffer] + batchedWal.write(buffer, 2L) } + } - // Verify that the corrupted files do not prevent reading of the second set of data - val readData = readDataUsingWriteAheadLog(testDir) - assert(readData === dataToWrite2) + // we make the write requests in separate threads so that we don't block the test thread + private def writeAsync(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = { + val p = Promise[Unit]() + p.completeWith(Future { + val v = wal.write(event, time) + assert(v === walHandle) + }(walBatchingExecutionContext)) + p } - test("FileBasedWriteAheadLog - do not create directories or files unless write") { - val nonexistentTempPath = File.createTempFile("test", "") - nonexistentTempPath.delete() - assert(!nonexistentTempPath.exists()) + test("BatchedWriteAheadLog - name log with aggregated entries with the timestamp of last entry") { + val blockingWal = new BlockingWriteAheadLog(wal, walHandle) + val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf) + + val event1 = "hello" + val event2 = "world" + val event3 = "this" + val event4 = "is" + val event5 = "doge" + + // The queue.take() immediately takes the 3, and there is nothing left in the queue at that + // moment. Then the promise blocks the writing of 3. The rest get queued. + writeAsync(batchedWal, event1, 3L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 0) + } + // rest of the records will be batched while it takes time for 3 to get written + writeAsync(batchedWal, event2, 5L) + writeAsync(batchedWal, event3, 8L) + writeAsync(batchedWal, event4, 12L) + writeAsync(batchedWal, event5, 10L) + eventually(timeout(1 second)) { + assert(walBatchingThreadPool.getActiveCount === 5) + assert(batchedWal.invokePrivate(queueLength()) === 4) + } + blockingWal.allowWrite() - val writtenSegment = writeDataManually(generateRandomData(), testFile) - val wal = new FileBasedWriteAheadLog(new SparkConf(), tempDir.getAbsolutePath, - new Configuration(), 1, 1, closeFileAfterWrite = false) - assert(!nonexistentTempPath.exists(), "Directory created just by creating log object") - wal.read(writtenSegment.head) - assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment") + val buffer1 = wrapArrayArrayByte(Array(event1)) + val buffer2 = wrapArrayArrayByte(Array(event2, event3, event4, event5)) + + eventually(timeout(1 second)) { + assert(batchedWal.invokePrivate(queueLength()) === 0) + verify(wal, times(1)).write(meq(buffer1), meq(3L)) + // the file name should be the timestamp of the last record, as events should be naturally + // in order of timestamp, and we need the last element. + verify(wal, times(1)).write(meq(buffer2), meq(10L)) + } } -} -object WriteAheadLogSuite { + test("BatchedWriteAheadLog - shutdown properly") { + val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) + batchedWal.close() + verify(wal, times(1)).close() - class MockWriteAheadLog0() extends WriteAheadLog { - override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { null } - override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = { null } - override def readAll(): util.Iterator[ByteBuffer] = { null } - override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { } - override def close(): Unit = { } + intercept[IllegalStateException](batchedWal.write(mock[ByteBuffer], 12L)) } - class MockWriteAheadLog1(val conf: SparkConf) extends MockWriteAheadLog0() + test("BatchedWriteAheadLog - fail everything in queue during shutdown") { + val blockingWal = new BlockingWriteAheadLog(wal, walHandle) + val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf) - class MockWriteAheadLog2(val conf: SparkConf, x: Int) extends MockWriteAheadLog0() + val event1 = "hello" + val event2 = "world" + val event3 = "this" + // The queue.take() immediately takes the 3, and there is nothing left in the queue at that + // moment. Then the promise blocks the writing of 3. The rest get queued. + val promise1 = writeAsync(batchedWal, event1, 3L) + eventually(timeout(1 second)) { + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 0) + } + // rest of the records will be batched while it takes time for 3 to get written + val promise2 = writeAsync(batchedWal, event2, 5L) + val promise3 = writeAsync(batchedWal, event3, 8L) + + eventually(timeout(1 second)) { + assert(walBatchingThreadPool.getActiveCount === 3) + assert(blockingWal.isBlocked) + assert(batchedWal.invokePrivate(queueLength()) === 2) // event1 is being written + } + + val writePromises = Seq(promise1, promise2, promise3) + + batchedWal.close() + eventually(timeout(1 second)) { + assert(writePromises.forall(_.isCompleted)) + assert(writePromises.forall(_.future.value.get.isFailure)) // all should have failed + } + } +} + +class BatchedWriteAheadLogWithCloseFileAfterWriteSuite + extends CloseFileAfterWriteTests(allowBatching = true, "BatchedWriteAheadLog") + +object WriteAheadLogSuite { private val hadoopConf = new Configuration() /** Write data to a file directly and return an array of the file segments written. */ - def writeDataManually(data: Seq[String], file: String): Seq[FileBasedWriteAheadLogSegment] = { + def writeDataManually( + data: Seq[String], + file: String, + allowBatching: Boolean): Seq[FileBasedWriteAheadLogSegment] = { val segments = new ArrayBuffer[FileBasedWriteAheadLogSegment]() val writer = HdfsUtils.getOutputStream(file, hadoopConf) - data.foreach { item => + def writeToStream(bytes: Array[Byte]): Unit = { val offset = writer.getPos - val bytes = Utils.serialize(item) writer.writeInt(bytes.size) writer.write(bytes) segments += FileBasedWriteAheadLogSegment(file, offset, bytes.size) } + if (allowBatching) { + writeToStream(wrapArrayArrayByte(data.toArray[String]).array()) + } else { + data.foreach { item => + writeToStream(Utils.serialize(item)) + } + } writer.close() segments } @@ -356,8 +598,7 @@ object WriteAheadLogSuite { */ def writeDataUsingWriter( filePath: String, - data: Seq[String] - ): Seq[FileBasedWriteAheadLogSegment] = { + data: Seq[String]): Seq[FileBasedWriteAheadLogSegment] = { val writer = new FileBasedWriteAheadLogWriter(filePath, hadoopConf) val segments = data.map { item => writer.write(item) @@ -370,13 +611,13 @@ object WriteAheadLogSuite { def writeDataUsingWriteAheadLog( logDirectory: String, data: Seq[String], + closeFileAfterWrite: Boolean, + allowBatching: Boolean, manualClock: ManualClock = new ManualClock, closeLog: Boolean = true, - clockAdvanceTime: Int = 500, - closeFileAfterWrite: Boolean = false): FileBasedWriteAheadLog = { + clockAdvanceTime: Int = 500): WriteAheadLog = { if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) - val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1, - closeFileAfterWrite) + val wal = createWriteAheadLog(logDirectory, closeFileAfterWrite, allowBatching) // Ensure that 500 does not get sorted after 2000, so put a high base value. data.foreach { item => @@ -406,16 +647,16 @@ object WriteAheadLogSuite { } /** Read all the data from a log file directly and return the list of byte buffers. */ - def readDataManually(file: String): Seq[String] = { + def readDataManually[T](file: String): Seq[T] = { val reader = HdfsUtils.getInputStream(file, hadoopConf) - val buffer = new ArrayBuffer[String] + val buffer = new ArrayBuffer[T] try { while (true) { // Read till EOF is thrown val length = reader.readInt() val bytes = new Array[Byte](length) reader.read(bytes) - buffer += Utils.deserialize[String](bytes) + buffer += Utils.deserialize[T](bytes) } } catch { case ex: EOFException => @@ -434,15 +675,17 @@ object WriteAheadLogSuite { } /** Read all the data in the log file in a directory using the WriteAheadLog class. */ - def readDataUsingWriteAheadLog(logDirectory: String): Seq[String] = { - val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1, - closeFileAfterWrite = false) - val data = wal.readAll().asScala.map(byteBufferToString).toSeq + def readDataUsingWriteAheadLog( + logDirectory: String, + closeFileAfterWrite: Boolean, + allowBatching: Boolean): Seq[String] = { + val wal = createWriteAheadLog(logDirectory, closeFileAfterWrite, allowBatching) + val data = wal.readAll().asScala.map(byteBufferToString).toArray wal.close() data } - /** Get the log files in a direction */ + /** Get the log files in a directory. */ def getLogFilesInDirectory(directory: String): Seq[String] = { val logDirectoryPath = new Path(directory) val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) @@ -458,10 +701,31 @@ object WriteAheadLogSuite { } } + def createWriteAheadLog( + logDirectory: String, + closeFileAfterWrite: Boolean, + allowBatching: Boolean): WriteAheadLog = { + val sparkConf = new SparkConf + val wal = new FileBasedWriteAheadLog(sparkConf, logDirectory, hadoopConf, 1, 1, + closeFileAfterWrite) + if (allowBatching) new BatchedWriteAheadLog(wal, sparkConf) else wal + } + def generateRandomData(): Seq[String] = { (1 to 100).map { _.toString } } + def readAndDeserializeDataManually(logFiles: Seq[String], allowBatching: Boolean): Seq[String] = { + if (allowBatching) { + logFiles.flatMap { file => + val data = readDataManually[Array[Array[Byte]]](file) + data.flatMap(byteArray => byteArray.map(Utils.deserialize[String])) + } + } else { + logFiles.flatMap { file => readDataManually[String](file)} + } + } + implicit def stringToByteBuffer(str: String): ByteBuffer = { ByteBuffer.wrap(Utils.serialize(str)) } @@ -469,4 +733,41 @@ object WriteAheadLogSuite { implicit def byteBufferToString(byteBuffer: ByteBuffer): String = { Utils.deserialize[String](byteBuffer.array) } + + def wrapArrayArrayByte[T](records: Array[T]): ByteBuffer = { + ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]](records.map(Utils.serialize[T]))) + } + + /** + * A wrapper WriteAheadLog that blocks the write function to allow batching with the + * BatchedWriteAheadLog. + */ + class BlockingWriteAheadLog( + wal: WriteAheadLog, + handle: WriteAheadLogRecordHandle) extends WriteAheadLog { + @volatile private var isWriteCalled: Boolean = false + @volatile private var blockWrite: Boolean = true + + override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { + isWriteCalled = true + eventually(Eventually.timeout(2 second)) { + assert(!blockWrite) + } + wal.write(record, time) + isWriteCalled = false + handle + } + override def read(segment: WriteAheadLogRecordHandle): ByteBuffer = wal.read(segment) + override def readAll(): JIterator[ByteBuffer] = wal.readAll() + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { + wal.clean(threshTime, waitForCompletion) + } + override def close(): Unit = wal.close() + + def allowWrite(): Unit = { + blockWrite = false + } + + def isBlocked: Boolean = isWriteCalled + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala new file mode 100644 index 0000000000000..bfc5b0cf60fb1 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import java.nio.ByteBuffer +import java.util + +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} +import org.apache.spark.util.Utils + +class WriteAheadLogUtilsSuite extends SparkFunSuite { + import WriteAheadLogUtilsSuite._ + + private val logDir = Utils.createTempDir().getAbsolutePath() + private val hadoopConf = new Configuration() + + def assertDriverLogClass[T <: WriteAheadLog: ClassTag]( + conf: SparkConf, + isBatched: Boolean = false): WriteAheadLog = { + val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf) + if (isBatched) { + assert(log.isInstanceOf[BatchedWriteAheadLog]) + val parentLog = log.asInstanceOf[BatchedWriteAheadLog].wrappedLog + assert(parentLog.getClass === implicitly[ClassTag[T]].runtimeClass) + } else { + assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) + } + log + } + + def assertReceiverLogClass[T <: WriteAheadLog: ClassTag](conf: SparkConf): WriteAheadLog = { + val log = WriteAheadLogUtils.createLogForReceiver(conf, logDir, hadoopConf) + assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) + log + } + + test("log selection and creation") { + + val emptyConf = new SparkConf() // no log configuration + assertDriverLogClass[FileBasedWriteAheadLog](emptyConf, isBatched = true) + assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf) + + // Verify setting driver WAL class + val driverWALConf = new SparkConf().set("spark.streaming.driver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[MockWriteAheadLog0](driverWALConf, isBatched = true) + assertReceiverLogClass[FileBasedWriteAheadLog](driverWALConf) + + // Verify setting receiver WAL class + val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf, isBatched = true) + assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) + + // Verify setting receiver WAL class with 1-arg constructor + val receiverWALConf2 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog1].getName()) + assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf2) + + // Verify failure setting receiver WAL class with 2-arg constructor + intercept[SparkException] { + val receiverWALConf3 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog2].getName()) + assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf3) + } + } + + test("wrap WriteAheadLog in BatchedWriteAheadLog when batching is enabled") { + def getBatchedSparkConf: SparkConf = + new SparkConf().set("spark.streaming.driver.writeAheadLog.allowBatching", "true") + + val justBatchingConf = getBatchedSparkConf + assertDriverLogClass[FileBasedWriteAheadLog](justBatchingConf, isBatched = true) + assertReceiverLogClass[FileBasedWriteAheadLog](justBatchingConf) + + // Verify setting driver WAL class + val driverWALConf = getBatchedSparkConf.set("spark.streaming.driver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[MockWriteAheadLog0](driverWALConf, isBatched = true) + assertReceiverLogClass[FileBasedWriteAheadLog](driverWALConf) + + // Verify receivers are not wrapped + val receiverWALConf = getBatchedSparkConf.set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf, isBatched = true) + assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) + } + + test("batching is enabled by default in WriteAheadLog") { + val conf = new SparkConf() + assert(WriteAheadLogUtils.isBatchingEnabled(conf, isDriver = true)) + // batching is not valid for receiver WALs + assert(!WriteAheadLogUtils.isBatchingEnabled(conf, isDriver = false)) + } + + test("closeFileAfterWrite is disabled by default in WriteAheadLog") { + val conf = new SparkConf() + assert(!WriteAheadLogUtils.shouldCloseFileAfterWrite(conf, isDriver = true)) + assert(!WriteAheadLogUtils.shouldCloseFileAfterWrite(conf, isDriver = false)) + } +} + +object WriteAheadLogUtilsSuite { + + class MockWriteAheadLog0() extends WriteAheadLog { + override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { null } + override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = { null } + override def readAll(): util.Iterator[ByteBuffer] = { null } + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { } + override def close(): Unit = { } + } + + class MockWriteAheadLog1(val conf: SparkConf) extends MockWriteAheadLog0() + + class MockWriteAheadLog2(val conf: SparkConf, x: Int) extends MockWriteAheadLog0() +} diff --git a/tags/src/main/java/org/apache/spark/tags/DockerTest.java b/tags/src/main/java/org/apache/spark/tags/DockerTest.java new file mode 100644 index 0000000000000..0fecf3b8f979a --- /dev/null +++ b/tags/src/main/java/org/apache/spark/tags/DockerTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.tags; + +import java.lang.annotation.*; +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface DockerTest { } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 4d9e777cb4134..7e39c3ea56af3 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -35,7 +35,7 @@ import org.apache.hadoop.yarn.util.RackResolver import org.apache.log4j.{Level, Logger} -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} @@ -96,6 +96,10 @@ private[yarn] class YarnAllocator( // was lost. private val pendingLossReasonRequests = new HashMap[String, mutable.Buffer[RpcCallContext]] + // Maintain loss reasons for already released executors, it will be added when executor loss + // reason is got from AM-RM call, and be removed after querying this loss reason. + private val releasedExecutorLossReasons = new HashMap[String, ExecutorLossReason] + // Keep track of which container is running which executor to remove the executors later // Visible for testing. private[yarn] val executorIdToContainer = new HashMap[String, Container] @@ -202,8 +206,7 @@ private[yarn] class YarnAllocator( */ def killExecutor(executorId: String): Unit = synchronized { if (executorIdToContainer.contains(executorId)) { - val container = executorIdToContainer.remove(executorId).get - containerIdToExecutorId.remove(container.getId) + val container = executorIdToContainer.get(executorId).get internalReleaseContainer(container) numExecutorsRunning -= 1 } else { @@ -514,9 +517,18 @@ private[yarn] class YarnAllocator( containerIdToExecutorId.remove(containerId).foreach { eid => executorIdToContainer.remove(eid) - pendingLossReasonRequests.remove(eid).foreach { pendingRequests => - // Notify application of executor loss reasons so it can decide whether it should abort - pendingRequests.foreach(_.reply(exitReason)) + pendingLossReasonRequests.remove(eid) match { + case Some(pendingRequests) => + // Notify application of executor loss reasons so it can decide whether it should abort + pendingRequests.foreach(_.reply(exitReason)) + + case None => + // We cannot find executor for pending reasons. This is because completed container + // is processed before querying pending result. We should store it for later query. + // This is usually happened when explicitly killing a container, the result will be + // returned in one AM-RM communication. So query RPC will be later than this completed + // container process. + releasedExecutorLossReasons.put(eid, exitReason) } if (!alreadyReleased) { // The executor could have gone away (like no route to host, node failure, etc) @@ -538,8 +550,14 @@ private[yarn] class YarnAllocator( if (executorIdToContainer.contains(eid)) { pendingLossReasonRequests .getOrElseUpdate(eid, new ArrayBuffer[RpcCallContext]) += context + } else if (releasedExecutorLossReasons.contains(eid)) { + // Executor is already released explicitly before getting the loss reason, so directly send + // the pre-stored lost reason + context.reply(releasedExecutorLossReasons.remove(eid).get) } else { logWarning(s"Tried to get the loss reason for non-existent executor $eid") + context.sendFailure( + new SparkException(s"Fail to find loss reason for non-existent executor $eid")) } } diff --git a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala index aa46ec5100f0e..94bf579dc8247 100644 --- a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala +++ b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala @@ -19,7 +19,6 @@ package org.apache.spark.network.shuffle import java.io.{IOException, File} import java.util.concurrent.ConcurrentMap -import com.google.common.annotations.VisibleForTesting import org.apache.hadoop.yarn.api.records.ApplicationId import org.fusesource.leveldbjni.JniDBFactory import org.iq80.leveldb.{DB, Options}