Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

case when improvement: avoid copy_if_else #2079

Merged
merged 10 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/main/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ set(CUDFJNI_INCLUDE_DIRS
add_library(
spark_rapids_jni SHARED
src/BloomFilterJni.cpp
src/CaseWhenJni.cpp
src/CastStringJni.cpp
src/DateTimeRebaseJni.cpp
src/DecimalUtilsJni.cpp
Expand All @@ -196,6 +197,7 @@ add_library(
src/SparkResourceAdaptorJni.cpp
src/ZOrderJni.cpp
src/bloom_filter.cu
src/case_when.cu
src/cast_decimal_to_string.cu
src/format_float.cu
src/cast_float_to_string.cu
Expand Down
37 changes: 37 additions & 0 deletions src/main/cpp/src/CaseWhenJni.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "case_when.hpp"
#include "cudf_jni_apis.hpp"

#include <vector>

extern "C" {

JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CaseWhen_selectFirstTrueIndex(
JNIEnv* env, jclass, jlongArray bool_cols)
{
JNI_NULL_CHECK(env, bool_cols, "array of column handles is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::jni::native_jpointerArray<cudf::column_view> n_cudf_bool_columns(env, bool_cols);
auto bool_column_views = n_cudf_bool_columns.get_dereferenced();
return cudf::jni::release_as_jlong(
spark_rapids_jni::select_first_true_index(cudf::table_view(bool_column_views)));
}
CATCH_STD(env, 0);
}
}
99 changes: 99 additions & 0 deletions src/main/cpp/src/case_when.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "case_when.hpp"

#include <cudf/column/column_factories.hpp>
#include <cudf/strings/detail/strings_column_factories.cuh>
#include <cudf/table/table_device_view.cuh>
#include <cudf/types.hpp>

#include <thrust/transform.h>

namespace spark_rapids_jni {
namespace detail {

/**
* Select the column index for the first true in bool columns for the specified row
*/
struct select_first_true_fn {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: We should wrap anything that is locally used in a source file into an anonymous namespace to avoid name clashing in the future with other source files.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.
Added anonymous namespace

// bool columns stores the results of executing `when` expressions
cudf::table_device_view const d_table;

/**
* The number of bool columns is the size of case when branches.
* Note: reuturned index may be out of bound, valid bound is [0, col_num)
* When returning col_num index, it means final result is NULL value or ELSE value.
*
* e.g.:
* CASE WHEN 'a' THEN 'A' END
* The number of bool columns is 1
* The number of scalars is 1
* Max index is 1 which means using NULL(all when exprs are false).
* CASE WHEN 'a' THEN 'A' ELSE '_' END
* The number of bool columns is 1
* The number of scalars is 2
* Max index is also 1 which means using else value '_'
*/
__device__ cudf::size_type operator()(std::size_t row_idx) const
{
auto col_num = d_table.num_columns();
for (auto col_idx = 0; col_idx < col_num; col_idx++) {
auto const& col = d_table.column(col_idx);
if (!col.is_null(row_idx) && col.element<bool>(row_idx)) {
// Predicate is true and not null
return col_idx;
}
}
return col_num;
}
};

std::unique_ptr<cudf::column> select_first_true_index(cudf::table_view const& when_bool_columns,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
// checks
auto const num_columns = when_bool_columns.num_columns();
CUDF_EXPECTS(num_columns > 0, "At least one column must be specified");
auto const row_count = when_bool_columns.num_rows();
if (row_count == 0) { // empty begets empty
return cudf::make_empty_column(cudf::type_id::INT32);
}
// make output column
auto ret = cudf::make_numeric_column(
cudf::data_type{cudf::type_id::INT32}, row_count, cudf::mask_state::ALL_VALID, stream, mr);

// select first true index
auto const d_table_ptr = cudf::table_device_view::create(when_bool_columns, stream);
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<cudf::size_type>(0),
thrust::make_counting_iterator<cudf::size_type>(row_count),
ret->mutable_view().begin<cudf::size_type>(),
select_first_true_fn{*d_table_ptr});
return ret;
}

} // namespace detail

std::unique_ptr<cudf::column> select_first_true_index(cudf::table_view const& when_bool_columns,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
return detail::select_first_true_index(when_bool_columns, stream, mr);
}

} // namespace spark_rapids_jni
60 changes: 60 additions & 0 deletions src/main/cpp/src/case_when.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/column/column.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>

#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <map>
#include <memory>
#include <string>
#include <variant>
#include <vector>

namespace spark_rapids_jni {

/**
*
* Select the column index for the first true in bool columns.
* For the row does not contain true, use end index(number of columns).
*
* e.g.:
* column 0 in table: true, false, false, false
* column 1 in table: false, true, false, false
* column 2 in table: false, false, true, false
*
* 1st row is: true, flase, false; first true index is 0
* 2nd row is: false, true, false; first true index is 1
* 3rd row is: false, flase, true; first true index is 2
* 4th row is: false, false, false; do not find true, set index to the end index 3
*
* output column: 0, 1, 2, 3
* In the `case when` context, here 3 index means using NULL value.
*
*/
std::unique_ptr<cudf::column> select_first_true_index(
cudf::table_view const& when_bool_columns,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

} // namespace spark_rapids_jni
83 changes: 83 additions & 0 deletions src/main/java/com/nvidia/spark/rapids/jni/CaseWhen.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.jni;

import ai.rapids.cudf.*;


/**
* Exedute SQL `case when` semantic.
* If there are multiple branches and each branch uses scalar to generator value,
* then it's fast to use this class because it does not generate temp string columns.
*
* E.g.:
* SQL is:
* select
* case
* when bool_1_expr then "value_1"
* when bool_2_expr then "value_2"
* when bool_3_expr then "value_3"
* else "value_else"
* end
* from tab
*
* Execution steps:
* Execute bool exprs to get bool columns, e.g., gets:
* bool column 1: [true, false, false, false] // bool_1_expr result
* bool column 2: [false, true, false, flase] // bool_2_expr result
* bool column 3: [false, false, true, flase] // bool_3_expr result
* Execute `selectFirstTrueIndex` to get the column index for the first true in bool columns.
* Generate a column to store salars: "value_1", "value_2", "value_3", "value_else"
* Execute `Table.gather` to generate the final output column
*
*/
public class CaseWhen {

/**
*
* Select the column index for the first true in bool columns.
* For the row does not contain true, use end index(number of columns).
*
* e.g.:
* column 0: true, false, false, false
* column 1: false, true, false, false
* column 2: false, false, true, false
*
* 1st row is: true, flase, false; first true index is 0
* 2nd row is: false, true, false; first true index is 1
* 3rd row is: false, flase, true; first true index is 2
* 4th row is: false, false, false; do not find true, set index to the end index 3
*
* output column: 0, 1, 2, 3
* In the `case when` context, here 3 index means using NULL value.
*
*/
public static ColumnVector selectFirstTrueIndex(ColumnVector[] boolColumns) {
for (ColumnVector cv : boolColumns) {
assert(cv.getType().equals(DType.BOOL8)) : "Columns must be bools";
}

long[] boolHandles = new long[boolColumns.length];
for (int i = 0; i < boolColumns.length; ++i) {
boolHandles[i] = boolColumns[i].getNativeView();
}

return new ColumnVector(selectFirstTrueIndex(boolHandles));
}

private static native long selectFirstTrueIndex(long[] boolHandles);
}
64 changes: 64 additions & 0 deletions src/test/java/com/nvidia/spark/rapids/jni/CaseWhenTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.jni;

import ai.rapids.cudf.*;

import org.junit.jupiter.api.Test;

import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual;

public class CaseWhenTest {

@Test
void selectIndexTest() {
try (
ColumnVector b0 = ColumnVector.fromBooleans(
true, false, false, false);
ColumnVector b1 = ColumnVector.fromBooleans(
true, true, false, false);
ColumnVector b2 = ColumnVector.fromBooleans(
false, false, true, false);
ColumnVector b3 = ColumnVector.fromBooleans(
true, true, true, false);
ColumnVector expected = ColumnVector.fromInts(0, 1, 2, 4)) {
ColumnVector[] boolColumns = new ColumnVector[] { b0, b1, b2, b3 };
try (ColumnVector actual = CaseWhen.selectFirstTrueIndex(boolColumns)) {
assertColumnsAreEqual(expected, actual);
}
}
}

@Test
void selectIndexTestWithNull() {
try (
ColumnVector b0 = ColumnVector.fromBoxedBooleans(
null, false, false, null, false);
ColumnVector b1 = ColumnVector.fromBoxedBooleans(
null, null, false, true, true);
ColumnVector b2 = ColumnVector.fromBoxedBooleans(
null, null, false, true, false);
ColumnVector b3 = ColumnVector.fromBoxedBooleans(
null, null, null, true, null);
ColumnVector expected = ColumnVector.fromInts(4, 4, 4, 1, 1)) {
ColumnVector[] boolColumns = new ColumnVector[] { b0, b1, b2, b3 };
try (ColumnVector actual = CaseWhen.selectFirstTrueIndex(boolColumns)) {
assertColumnsAreEqual(expected, actual);
}
}
}
}
Loading