[SPARK-17658][SPARKR] read.df/write.df API taking path optionally in SparkR #15231

Closed
wants to merge 13 commits
20 changes: 15 additions & 5 deletions R/pkg/R/DataFrame.R
@@ -2608,7 +2608,7 @@ setMethod("except",
#' @param ... additional argument(s) passed to the method.
#'
#' @family SparkDataFrame functions
#' @aliases write.df,SparkDataFrame,character-method
#' @aliases write.df,SparkDataFrame-method
#' @rdname write.df
#' @name write.df
#' @export
@@ -2622,21 +2622,31 @@ setMethod("except",
#' }
#' @note write.df since 1.4.0
setMethod("write.df",
signature(df = "SparkDataFrame", path = "character"),
function(df, path, source = NULL, mode = "error", ...) {
signature(df = "SparkDataFrame"),
function(df, path = NULL, source = NULL, mode = "error", ...) {
if (!is.null(path) && !is.character(path)) {
stop("path should be character, NULL or omitted.")
}
if (!is.null(source) && !is.character(source)) {
stop("source should be character, NULL or omitted. It is the datasource specified ",
"in 'spark.sql.sources.default' configuration by default.")
}
if (!is.character(mode)) {
stop("mode should be character or omitted. It is 'error' by default.")
}
if (is.null(source)) {
source <- getDefaultSqlSource()
}
jmode <- convertToJSaveMode(mode)
options <- varargsToEnv(...)
if (!is.null(path)) {
options[["path"]] <- path
options[["path"]] <- path
}
write <- callJMethod(df@sdf, "write")
write <- callJMethod(write, "format", source)
write <- callJMethod(write, "mode", jmode)
write <- callJMethod(write, "options", options)
write <- callJMethod(write, "save", path)
write <- handledCallJMethod(write, "save")
})

#' @rdname write.df
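For illustration, a minimal usage sketch of the relaxed write.df signature. Everything here is assumed rather than taken from the diff: `df` is an existing SparkDataFrame, the paths are placeholders, and the "jdbc" source with `url`/`dbtable` options is just one example of a source that can save without a path.

```r
# Assumes an active SparkSession; df could be created with df <- createDataFrame(mtcars).

# Path-based sources keep working exactly as before.
write.df(df, path = "/tmp/cars_parquet", source = "parquet", mode = "overwrite")

# With path now optional, a source that takes its location from options
# (a JDBC-style source, for example) can be written without a path; the extra
# named arguments are collected by varargsToEnv(...) and passed as options.
write.df(df, source = "jdbc", mode = "overwrite",
         url = "jdbc:postgresql://localhost/testdb", dbtable = "cars")
```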
19 changes: 13 additions & 6 deletions R/pkg/R/SQLContext.R
@@ -771,6 +771,13 @@ dropTempView <- function(viewName) {
#' @method read.df default
#' @note read.df since 1.4.0
read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.strings = "NA", ...) {
if (!is.null(path) && !is.character(path)) {
stop("path should be character, NULL or omitted.")
}
if (!is.null(source) && !is.character(source)) {
stop("source should be character, NULL or omitted. It is the datasource specified ",
"in 'spark.sql.sources.default' configuration by default.")
}
sparkSession <- getSparkSession()
options <- varargsToEnv(...)
if (!is.null(path)) {
@@ -784,16 +791,16 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.string
}
if (!is.null(schema)) {
stopifnot(class(schema) == "structType")
sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, source,
schema$jobj, options)
sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession,
source, schema$jobj, options)
} else {
sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
"loadDF", sparkSession, source, options)
sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession,
source, options)
}
dataFrame(sdf)
}

read.df <- function(x, ...) {
read.df <- function(x = NULL, ...) {
dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", x, ...)
}

Expand All @@ -805,7 +812,7 @@ loadDF.default <- function(path = NULL, source = NULL, schema = NULL, ...) {
read.df(path, source, schema, ...)
}

loadDF <- function(x, ...) {
loadDF <- function(x = NULL, ...) {
dispatchFunc("loadDF(path = NULL, source = NULL, schema = NULL, ...)", x, ...)
}

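Similarly, a hedged sketch of read.df with and without the now-optional path (the JSON file path is a placeholder, not from this diff):

```r
# Reading with an explicit path works as before.
people <- read.df("/tmp/people.json", source = "json")

# Omitting the path is now allowed at the R level; whether it succeeds is up to
# the underlying source. For "json" there is nothing to infer a schema from, so
# this surfaces the JVM analysis error exercised in the tests further below.
# read.df(source = "json")
```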
4 changes: 2 additions & 2 deletions R/pkg/R/generics.R
@@ -633,7 +633,7 @@ setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") })

#' @rdname write.df
#' @export
setGeneric("write.df", function(df, path, source = NULL, mode = "error", ...) {
setGeneric("write.df", function(df, path = NULL, source = NULL, mode = "error", ...) {
standardGeneric("write.df")
})

@@ -732,7 +732,7 @@ setGeneric("withColumnRenamed",

#' @rdname write.df
#' @export
setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") })
setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.df") })

#' @rdname randomSplit
#' @export
52 changes: 52 additions & 0 deletions R/pkg/R/utils.R
@@ -698,6 +698,58 @@ isSparkRShell <- function() {
grepl(".*shell\\.R$", Sys.getenv("R_PROFILE_USER"), perl = TRUE)
}

# Works identically to `callJStatic(...)`, but rethrows JVM exceptions as neatly formatted R errors.
handledCallJStatic <- function(cls, method, ...) {
result <- tryCatch(callJStatic(cls, method, ...),
error = function(e) {
captureJVMException(e, method)
})
result
}

# Works identically to `callJMethod(...)`, but rethrows JVM exceptions as neatly formatted R errors.
handledCallJMethod <- function(obj, method, ...) {
result <- tryCatch(callJMethod(obj, method, ...),
error = function(e) {
captureJVMException(e, method)
})
result
}

captureJVMException <- function(e, method) {
rawmsg <- as.character(e)
if (any(grep("^Error in .*?: ", rawmsg))) {
# If the exception message starts with "Error in ...", it most likely came from
# "Error in invokeJava(...)". Replace that prefix with
# `paste("Error in", method, ":")` so the rewritten message identifies which
# function was called on the JVM side.
stacktrace <- strsplit(rawmsg, "Error in .*?: ")[[1]]
Review comment (Member):
very minor nit: you could probably replace the double pass with grep above and strsplit with just the result from strsplit

rmsg <- paste("Error in", method, ":")
stacktrace <- paste(rmsg[1], stacktrace[2])
} else {
# Otherwise, leave the error message unchanged, to be safe.
stacktrace <- rawmsg
}

if (any(grep("java.lang.IllegalArgumentException: ", stacktrace))) {
Review comment (Member):
are there cases where the IllegalArgument should be checked on the R side first to avoid the exception in the first place?

Review comment (Member Author):
Thanks! @felixcheung I will address all other comments above. However, for this one, I was thinking hard but it seems not easy, because on the R side we won't know whether a given data source is valid or not.

I might be able to do this only for internal data sources or known Databricks datasources such as "redshift" or "xml", e.g. by creating a map of our internal data sources and then checking whether a path is given. However, I am not sure it is a good idea to manage another list of datasources.

Review comment (@felixcheung, Sep 26, 2016):
I agree, I don't think we should couple the R code to the underlying data source implementations, and was not suggesting that :)

I guess I'm saying there are still many (other) cases where the parameters are unchecked, and it would be good to see whether this check to convert the JVM IllegalArgumentException is sufficient or more checks should be added on the R side.

Review comment (Member Author):
Ah, I see. Yeap. This might be a best-effort thing. I think I tried (if I am right) all combinations of missing/wrong parameters in the APIs. One exceptional case for both APIs is that they throw a ClassCastException when the extra options are wrongly typed, which I think we should check on the R side; this will be handled in #15239.
I might open another PR for validating parameters across SparkR if you think it is okay.

Review comment (Member):
great, thanks - generally I'd prefer having parameter checks in R; though in this case I think we need to balance the added code complexity against reduced usability (by checking more, it might fail where it didn't before).

so I'm not 100% sure we should add parameter checks all across the board.

Review comment (Member Author):
Yeap, I do understand and will investigate it with keeping this in mind :)

msg <- strsplit(stacktrace, "java.lang.IllegalArgumentException: ", fixed = TRUE)[[1]]
# Extract "Error in ..." message.
rmsg <- msg[1]
# Extract the first message of JVM exception.
first <- strsplit(msg[2], "\r?\n\tat")[[1]][1]
stop(paste0(rmsg, "illegal argument - ", first), call. = FALSE)
} else if (any(grep("org.apache.spark.sql.AnalysisException: ", stacktrace))) {
msg <- strsplit(stacktrace, "org.apache.spark.sql.AnalysisException: ", fixed = TRUE)[[1]]
# Extract "Error in ..." message.
rmsg <- msg[1]
# Extract the first message of JVM exception.
first <- strsplit(msg[2], "\r?\n\tat")[[1]][1]
stop(paste0(rmsg, "analysis error - ", first), call. = FALSE)
} else {
stop(stacktrace, call. = FALSE)
}
}
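As a rough sanity check of the message rewriting above, a self-contained sketch with a fabricated error; `fakeInvoke` and its message are synthetic stand-ins for what callJMethod/callJStatic raise on a JVM failure, not output from a real call.

```r
# Fabricate a condition that looks like a JVM-side IllegalArgumentException.
fakeInvoke <- function() {
  stop(paste0("java.lang.IllegalArgumentException: 'path' is not specified\n",
              "\tat org.apache.spark.sql.execution.datasources.DataSource.write"))
}
e <- tryCatch(fakeInvoke(), error = function(err) err)
# as.character(e) is roughly "Error in fakeInvoke(): java.lang.IllegalArgumentException: ..."

tryCatch(captureJVMException(e, "save"),
         error = function(err) conditionMessage(err))
# Expected to return: "Error in save : illegal argument - 'path' is not specified"
```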

# rbind a list of rows with raw (binary) columns
#
# @param inputData a list of rows, with each row a list
35 changes: 35 additions & 0 deletions R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2544,6 +2544,41 @@ test_that("Spark version from SparkSession", {
expect_equal(ver, version)
})

test_that("Call DataFrameWriter.save() API in Java without path and check argument types", {
df <- read.df(jsonPath, "json")
# This tests that the exception is thrown from the JVM, not from the SparkR side.
# It makes sure that we can omit the path argument in the write.df API, in which case
# DataFrameWriter.save() is called without a path.
expect_error(write.df(df, source = "csv"),
"Error in save : illegal argument - 'path' is not specified")

# Arguments checking in R side.
expect_error(write.df(df, "data.tmp", source = c(1, 2)),
paste("source should be character, NULL or omitted. It is the datasource specified",
"in 'spark.sql.sources.default' configuration by default."))
expect_error(write.df(df, path = c(3)),
"path should be character, NULL or omitted.")
expect_error(write.df(df, mode = TRUE),
"mode should be character or omitted. It is 'error' by default.")
})

test_that("Call DataFrameReader.load() API in Java without path and check argument types", {
# This tests that the exception is thrown from the JVM, not from the SparkR side.
# It makes sure that we can omit the path argument in the read.df API, in which case
# DataFrameReader.load() is called without a path.
expect_error(read.df(source = "json"),
paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .",
"It must be specified manually"))
expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist")

# Arguments checking in R side.
expect_error(read.df(path = c(3)),
"path should be character, NULL or omitted.")
expect_error(read.df(jsonPath, source = c(1, 2)),
paste("source should be character, NULL or omitted. It is the datasource specified",
"in 'spark.sql.sources.default' configuration by default."))
})

unlink(parquetPath)
unlink(orcPath)
unlink(jsonPath)
10 changes: 10 additions & 0 deletions R/pkg/inst/tests/testthat/test_utils.R
@@ -166,6 +166,16 @@ test_that("convertToJSaveMode", {
'mode should be one of "append", "overwrite", "error", "ignore"') #nolint
})

test_that("captureJVMException", {
method <- "getSQLDataType"
expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method,
"unknown"),
error = function(e) {
captureJVMException(e, method)
}),
"Error in getSQLDataType : illegal argument - Invalid type unknown")
})

Review comment (@felixcheung, Oct 5, 2016):
let's change this test to handledCallJStatic too?
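A sketch of what that suggested change might look like (hypothetical, not part of this diff), reusing the handledCallJStatic wrapper added in utils.R above:

```r
test_that("captureJVMException", {
  method <- "getSQLDataType"
  expect_error(handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", method, "unknown"),
               "Error in getSQLDataType : illegal argument - Invalid type unknown")
})
```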

test_that("hashCode", {
expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA)
})