Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

case when improvement: avoid copy_if_else #2079

Merged
merged 10 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/main/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ set(CUDFJNI_INCLUDE_DIRS
add_library(
spark_rapids_jni SHARED
src/BloomFilterJni.cpp
src/CaseWhenJni.cpp
src/CastStringJni.cpp
src/DateTimeRebaseJni.cpp
src/DecimalUtilsJni.cpp
Expand All @@ -196,6 +197,7 @@ add_library(
src/SparkResourceAdaptorJni.cpp
src/ZOrderJni.cpp
src/bloom_filter.cu
src/case_when.cu
src/cast_decimal_to_string.cu
src/format_float.cu
src/cast_float_to_string.cu
Expand Down
35 changes: 35 additions & 0 deletions src/main/cpp/src/CaseWhenJni.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "case_when.hpp"
#include "cudf_jni_apis.hpp"

extern "C" {

JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CaseWhen_selectFirstTrueIndex(
JNIEnv* env, jclass, jlongArray bool_cols)
{
JNI_NULL_CHECK(env, bool_cols, "array of column handles is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::jni::native_jpointerArray<cudf::column_view> n_cudf_bool_columns(env, bool_cols);
auto bool_column_views = n_cudf_bool_columns.get_dereferenced();
return cudf::jni::release_as_jlong(
spark_rapids_jni::select_first_true_index(cudf::table_view(bool_column_views)));
}
CATCH_STD(env, 0);
}
}
102 changes: 102 additions & 0 deletions src/main/cpp/src/case_when.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "case_when.hpp"

#include <cudf/column/column_factories.hpp>
#include <cudf/strings/detail/strings_column_factories.cuh>
#include <cudf/table/table_device_view.cuh>
#include <cudf/types.hpp>

#include <thrust/transform.h>

namespace spark_rapids_jni {
namespace detail {
namespace {

/**
* Select the column index for the first true in bool columns for the specified row
*/
struct select_first_true_fn {
// bool columns stores the results of executing `when` expressions
cudf::table_device_view const d_table;

/**
* The number of bool columns is the size of case when branches.
* Note: reuturned index may be out of bound, valid bound is [0, col_num)
* When returning col_num index, it means final result is NULL value or ELSE value.
*
* e.g.:
* CASE WHEN 'a' THEN 'A' END
* The number of bool columns is 1
* The number of scalars is 1
* Max index is 1 which means using NULL(all when exprs are false).
* CASE WHEN 'a' THEN 'A' ELSE '_' END
* The number of bool columns is 1
* The number of scalars is 2
* Max index is also 1 which means using else value '_'
*/
__device__ cudf::size_type operator()(std::size_t row_idx) const
{
auto col_num = d_table.num_columns();
for (auto col_idx = 0; col_idx < col_num; col_idx++) {
auto const& col = d_table.column(col_idx);
if (!col.is_null(row_idx) && col.element<bool>(row_idx)) {
// Predicate is true and not null
return col_idx;
}
}
return col_num;
}
};

} // anonymous namespace

std::unique_ptr<cudf::column> select_first_true_index(cudf::table_view const& when_bool_columns,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
// checks
auto const num_columns = when_bool_columns.num_columns();
CUDF_EXPECTS(num_columns > 0, "At least one column must be specified");
auto const row_count = when_bool_columns.num_rows();
if (row_count == 0) { // empty begets empty
return cudf::make_empty_column(cudf::type_id::INT32);
}
// make output column
auto ret = cudf::make_numeric_column(
cudf::data_type{cudf::type_id::INT32}, row_count, cudf::mask_state::UNALLOCATED, stream, mr);

// select first true index
auto const d_table_ptr = cudf::table_device_view::create(when_bool_columns, stream);
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<cudf::size_type>(0),
thrust::make_counting_iterator<cudf::size_type>(row_count),
ret->mutable_view().begin<cudf::size_type>(),
select_first_true_fn{*d_table_ptr});
return ret;
}

} // namespace detail

std::unique_ptr<cudf::column> select_first_true_index(cudf::table_view const& when_bool_columns,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
return detail::select_first_true_index(when_bool_columns, stream, mr);
}

} // namespace spark_rapids_jni
53 changes: 53 additions & 0 deletions src/main/cpp/src/case_when.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/table/table_view.hpp>
#include <cudf/utilities/default_stream.hpp>

#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <memory>

namespace spark_rapids_jni {

/**
*
* Select the column index for the first true in bool columns.
* For the row does not contain true, use end index(number of columns).
*
* e.g.:
* column 0 in table: true, false, false, false
* column 1 in table: false, true, false, false
* column 2 in table: false, false, true, false
*
* 1st row is: true, flase, false; first true index is 0
* 2nd row is: false, true, false; first true index is 1
* 3rd row is: false, flase, true; first true index is 2
* 4th row is: false, false, false; do not find true, set index to the end index 3
*
* output column: 0, 1, 2, 3
* In the `case when` context, here 3 index means using NULL value.
*
*/
std::unique_ptr<cudf::column> select_first_true_index(
cudf::table_view const& when_bool_columns,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

} // namespace spark_rapids_jni
83 changes: 83 additions & 0 deletions src/main/java/com/nvidia/spark/rapids/jni/CaseWhen.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.jni;

import ai.rapids.cudf.*;


/**
* Exedute SQL `case when` semantic.
* If there are multiple branches and each branch uses scalar to generator value,
* then it's fast to use this class because it does not generate temp string columns.
*
* E.g.:
* SQL is:
* select
* case
* when bool_1_expr then "value_1"
* when bool_2_expr then "value_2"
* when bool_3_expr then "value_3"
* else "value_else"
* end
* from tab
*
* Execution steps:
* Execute bool exprs to get bool columns, e.g., gets:
* bool column 1: [true, false, false, false] // bool_1_expr result
* bool column 2: [false, true, false, flase] // bool_2_expr result
* bool column 3: [false, false, true, flase] // bool_3_expr result
* Execute `selectFirstTrueIndex` to get the column index for the first true in bool columns.
* Generate a column to store salars: "value_1", "value_2", "value_3", "value_else"
* Execute `Table.gather` to generate the final output column
*
*/
public class CaseWhen {

/**
*
* Select the column index for the first true in bool columns.
* For the row does not contain true, use end index(number of columns).
*
* e.g.:
* column 0: true, false, false, false
* column 1: false, true, false, false
* column 2: false, false, true, false
*
* 1st row is: true, flase, false; first true index is 0
* 2nd row is: false, true, false; first true index is 1
* 3rd row is: false, flase, true; first true index is 2
* 4th row is: false, false, false; do not find true, set index to the end index 3
*
* output column: 0, 1, 2, 3
* In the `case when` context, here 3 index means using NULL value.
*
*/
public static ColumnVector selectFirstTrueIndex(ColumnVector[] boolColumns) {
for (ColumnVector cv : boolColumns) {
assert(cv.getType().equals(DType.BOOL8)) : "Columns must be bools";
}

long[] boolHandles = new long[boolColumns.length];
for (int i = 0; i < boolColumns.length; ++i) {
boolHandles[i] = boolColumns[i].getNativeView();
}

return new ColumnVector(selectFirstTrueIndex(boolHandles));
}

private static native long selectFirstTrueIndex(long[] boolHandles);
}
64 changes: 64 additions & 0 deletions src/test/java/com/nvidia/spark/rapids/jni/CaseWhenTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.jni;

import ai.rapids.cudf.*;

import org.junit.jupiter.api.Test;

import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual;

public class CaseWhenTest {

@Test
void selectIndexTest() {
try (
ColumnVector b0 = ColumnVector.fromBooleans(
true, false, false, false);
ColumnVector b1 = ColumnVector.fromBooleans(
true, true, false, false);
ColumnVector b2 = ColumnVector.fromBooleans(
false, false, true, false);
ColumnVector b3 = ColumnVector.fromBooleans(
true, true, true, false);
ColumnVector expected = ColumnVector.fromInts(0, 1, 2, 4)) {
ColumnVector[] boolColumns = new ColumnVector[] { b0, b1, b2, b3 };
try (ColumnVector actual = CaseWhen.selectFirstTrueIndex(boolColumns)) {
assertColumnsAreEqual(expected, actual);
}
}
}

@Test
void selectIndexTestWithNull() {
try (
ColumnVector b0 = ColumnVector.fromBoxedBooleans(
null, false, false, null, false);
ColumnVector b1 = ColumnVector.fromBoxedBooleans(
null, null, false, true, true);
ColumnVector b2 = ColumnVector.fromBoxedBooleans(
null, null, false, true, false);
ColumnVector b3 = ColumnVector.fromBoxedBooleans(
null, null, null, true, null);
ColumnVector expected = ColumnVector.fromInts(4, 4, 4, 1, 1)) {
ColumnVector[] boolColumns = new ColumnVector[] { b0, b1, b2, b3 };
try (ColumnVector actual = CaseWhen.selectFirstTrueIndex(boolColumns)) {
assertColumnsAreEqual(expected, actual);
}
}
}
}
Loading