diff --git a/velox/docs/functions/spark/misc.rst b/velox/docs/functions/spark/misc.rst index 362c76a3c20b..0238d8018adc 100644 --- a/velox/docs/functions/spark/misc.rst +++ b/velox/docs/functions/spark/misc.rst @@ -13,6 +13,11 @@ Miscellaneous Functions The function relies on partition IDs, which are provided by the framework via the configuration 'spark.partition_id'. +.. spark:function:: raise_error(message) + + Throws a user error with the specified ``message``. + If ``message`` is NULL, throws a user error with empty message. + .. spark:function:: spark_partition_id() -> integer Returns the current partition id. diff --git a/velox/functions/sparksql/RaiseError.h b/velox/functions/sparksql/RaiseError.h new file mode 100644 index 000000000000..9b7e65b22584 --- /dev/null +++ b/velox/functions/sparksql/RaiseError.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/functions/Macros.h" +namespace facebook::velox::functions::sparksql { + +template +struct RaiseErrorFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE Status + callNullable(out_type& result, const arg_type* input) { + if (input) { + return Status::UserError("{}", *input); + } + return Status::UserError(); + } +}; +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 1978891e6923..f974718cdb1d 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -38,6 +38,7 @@ #include "velox/functions/sparksql/LeastGreatest.h" #include "velox/functions/sparksql/MightContain.h" #include "velox/functions/sparksql/MonotonicallyIncreasingId.h" +#include "velox/functions/sparksql/RaiseError.h" #include "velox/functions/sparksql/RegexFunctions.h" #include "velox/functions/sparksql/RegisterArithmetic.h" #include "velox/functions/sparksql/RegisterCompare.h" @@ -461,6 +462,9 @@ void registerFunctions(const std::string& prefix) { Array>>>({prefix + "flatten"}); registerFunction({prefix + "soundex"}); + + registerFunction( + {prefix + "raise_error"}); } } // namespace sparksql diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index af4b4c3f66da..0adacad86f1f 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -38,6 +38,7 @@ add_executable( MapTest.cpp MightContainTest.cpp MonotonicallyIncreasingIdTest.cpp + RaiseErrorTest.cpp RandTest.cpp RegexFunctionsTest.cpp SizeTest.cpp diff --git a/velox/functions/sparksql/tests/RaiseErrorTest.cpp b/velox/functions/sparksql/tests/RaiseErrorTest.cpp new file mode 100644 index 000000000000..716428a8d962 --- /dev/null +++ b/velox/functions/sparksql/tests/RaiseErrorTest.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" + +namespace facebook::velox::functions::sparksql::test { +namespace { + +class RaiseErrorTest : public SparkFunctionBaseTest { + protected: + void raiseError(const std::optional& message) { + evaluateOnce("raise_error(c0)", message); + } +}; + +TEST_F(RaiseErrorTest, basic) { + VELOX_ASSERT_USER_THROW(raiseError(""), ""); + VELOX_ASSERT_USER_THROW(raiseError("0 > 1 is not true"), "0 > 1 is not true"); + VELOX_ASSERT_USER_THROW(raiseError(std::nullopt), ""); +} +} // namespace +} // namespace facebook::velox::functions::sparksql::test