Skip to content

Commit

Permalink
Add Spark raise_error function (#10110)
Browse files Browse the repository at this point in the history
Summary:
This PR adds Spark `raise_error` function. It is re-used in Spark `assert_true`
function, so after this PR `assert_true` can also be supported.

Spark document: https://spark.apache.org/docs/latest/api/sql/#raise_error
Spark code: https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala#L85C1-L92C1
Issue: apache/incubator-gluten#5991

Pull Request resolved: #10110

Reviewed By: kevinwilfong

Differential Revision: D59252851

Pulled By: bikramSingh91

fbshipit-source-id: d41347bed8265825bb415e8604179b38ecc89bc8
  • Loading branch information
gaoyangxiaozhu authored and facebook-github-bot committed Jul 3, 2024
1 parent 63cceca commit 4e39b06
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 0 deletions.
5 changes: 5 additions & 0 deletions velox/docs/functions/spark/misc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ Miscellaneous Functions
The function relies on partition IDs, which are provided by the framework
via the configuration 'spark.partition_id'.

.. spark:function:: raise_error(message)
Throws a user error with the specified ``message``.
If ``message`` is NULL, throws a user error with empty message.

.. spark:function:: spark_partition_id() -> integer
Returns the current partition id.
Expand Down
33 changes: 33 additions & 0 deletions velox/functions/sparksql/RaiseError.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/functions/Macros.h"
namespace facebook::velox::functions::sparksql {

template <typename T>
struct RaiseErrorFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE Status
callNullable(out_type<UnknownValue>& result, const arg_type<Varchar>* input) {
if (input) {
return Status::UserError("{}", *input);
}
return Status::UserError();
}
};
} // namespace facebook::velox::functions::sparksql
4 changes: 4 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "velox/functions/sparksql/LeastGreatest.h"
#include "velox/functions/sparksql/MightContain.h"
#include "velox/functions/sparksql/MonotonicallyIncreasingId.h"
#include "velox/functions/sparksql/RaiseError.h"
#include "velox/functions/sparksql/RegexFunctions.h"
#include "velox/functions/sparksql/RegisterArithmetic.h"
#include "velox/functions/sparksql/RegisterCompare.h"
Expand Down Expand Up @@ -461,6 +462,9 @@ void registerFunctions(const std::string& prefix) {
Array<Array<Generic<T1>>>>({prefix + "flatten"});

registerFunction<SoundexFunction, Varchar, Varchar>({prefix + "soundex"});

registerFunction<RaiseErrorFunction, UnknownValue, Varchar>(
{prefix + "raise_error"});
}

} // namespace sparksql
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ add_executable(
MapTest.cpp
MightContainTest.cpp
MonotonicallyIncreasingIdTest.cpp
RaiseErrorTest.cpp
RandTest.cpp
RegexFunctionsTest.cpp
SizeTest.cpp
Expand Down
36 changes: 36 additions & 0 deletions velox/functions/sparksql/tests/RaiseErrorTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"

namespace facebook::velox::functions::sparksql::test {
namespace {

class RaiseErrorTest : public SparkFunctionBaseTest {
protected:
void raiseError(const std::optional<std::string>& message) {
evaluateOnce<UnknownValue>("raise_error(c0)", message);
}
};

TEST_F(RaiseErrorTest, basic) {
VELOX_ASSERT_USER_THROW(raiseError(""), "");
VELOX_ASSERT_USER_THROW(raiseError("0 > 1 is not true"), "0 > 1 is not true");
VELOX_ASSERT_USER_THROW(raiseError(std::nullopt), "");
}
} // namespace
} // namespace facebook::velox::functions::sparksql::test

0 comments on commit 4e39b06

Please sign in to comment.