Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

case when improvement: avoid copy_if_else #2079

Merged
merged 10 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/main/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ set(CUDFJNI_INCLUDE_DIRS
add_library(
spark_rapids_jni SHARED
src/BloomFilterJni.cpp
src/CaseWhenJni.cpp
src/CastStringJni.cpp
src/DateTimeRebaseJni.cpp
src/DecimalUtilsJni.cpp
Expand All @@ -167,6 +168,7 @@ add_library(
src/SparkResourceAdaptorJni.cpp
src/ZOrderJni.cpp
src/bloom_filter.cu
src/case_when.cu
src/cast_decimal_to_string.cu
src/format_float.cu
src/cast_float_to_string.cu
Expand Down
55 changes: 55 additions & 0 deletions src/main/cpp/src/CaseWhenJni.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "case_when.hpp"
#include "cudf_jni_apis.hpp"

#include <vector>

extern "C" {

JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CaseWhen_selectFirstTrueIndex(
JNIEnv* env, jclass, jlongArray bool_cols)
{
JNI_NULL_CHECK(env, bool_cols, "array of column handles is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::jni::native_jpointerArray<cudf::column_view> n_cudf_bool_columns(env, bool_cols);
auto bool_column_views = n_cudf_bool_columns.get_dereferenced();
return cudf::jni::release_as_jlong(
spark_rapids_jni::select_first_true_index(cudf::table_view(bool_column_views)));
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CaseWhen_selectFromIndex(JNIEnv* env,
jclass,
jlong scalar_cols,
jlong index_col)
{
JNI_NULL_CHECK(env, scalar_cols, "Column handles is null", 0);
JNI_NULL_CHECK(env, index_col, "Column handles is null", 0);
try {
cudf::jni::auto_set_device(env);
auto const scalar_column_view = reinterpret_cast<cudf::column_view const*>(scalar_cols);
auto const scalar_strings_col_view = cudf::strings_column_view{*scalar_column_view};
auto const index_column_view = reinterpret_cast<cudf::column_view const*>(index_col);
return cudf::jni::release_as_jlong(
spark_rapids_jni::select_from_index(scalar_strings_col_view, *index_column_view));
}
CATCH_STD(env, 0);
}
}
151 changes: 151 additions & 0 deletions src/main/cpp/src/case_when.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "case_when.hpp"

#include <cudf/column/column_factories.hpp>
#include <cudf/strings/detail/strings_column_factories.cuh>
#include <cudf/table/table_device_view.cuh>
#include <cudf/types.hpp>

#include <thrust/transform.h>

namespace spark_rapids_jni {
namespace detail {

/**
* Select the column index for the first true in bool columns for the specified row
*/
struct select_first_true_fn {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: We should wrap anything that is locally used in a source file into an anonymous namespace to avoid name clashing in the future with other source files.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.
Added anonymous namespace

// bool columns stores the results of executing `when` expressions
cudf::table_device_view const d_table;

/**
* The number of bool columns is the size of case when branches.
* Note: reuturned index may be out of bound, valid bound is [0, col_num)
* When returning col_num index, it means final result is NULL value or ELSE value.
*
* e.g.:
* CASE WHEN 'a' THEN 'A' END
* The number of bool columns is 1
* The number of scalars is 1
* Max index is 1 which means using NULL(all when exprs are false).
* CASE WHEN 'a' THEN 'A' ELSE '_' END
* The number of bool columns is 1
* The number of scalars is 2
* Max index is also 1 which means using else value '_'
*/
__device__ cudf::size_type operator()(std::size_t row_idx)
{
auto col_num = d_table.num_columns();
bool found_true = false;
cudf::size_type first_true_index = col_num;
for (auto col_idx = 0; !found_true && col_idx < col_num; col_idx++) {
auto const& col = d_table.column(col_idx);
if (!col.is_null(row_idx) && col.element<bool>(row_idx)) {
// Predicate is true and not null
found_true = true;
first_true_index = col_idx;
}
}
return first_true_index;
}
};

std::unique_ptr<cudf::column> select_first_true_index(cudf::table_view const& when_bool_columns,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
// checks
auto const num_columns = when_bool_columns.num_columns();
CUDF_EXPECTS(num_columns > 0, "At least one column must be specified");
auto const row_count = when_bool_columns.num_rows();
if (row_count == 0) // empty begets empty
return cudf::make_empty_column(cudf::type_id::INT32);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (row_count == 0) // empty begets empty
return cudf::make_empty_column(cudf::type_id::INT32);
if (row_count == 0) { // empty begets empty
return cudf::make_empty_column(cudf::type_id::INT32);
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.


// make output column
auto ret = cudf::make_numeric_column(
cudf::data_type{cudf::type_id::INT32}, row_count, cudf::mask_state::ALL_VALID, stream, mr);

// select first true index
auto d_table = cudf::table_device_view::create(when_bool_columns, stream);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto d_table = cudf::table_device_view::create(when_bool_columns, stream);
auto const d_table_ptr = cudf::table_device_view::create(when_bool_columns, stream);

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<cudf::size_type>(0),
thrust::make_counting_iterator<cudf::size_type>(row_count),
ret->mutable_view().begin<cudf::size_type>(),
select_first_true_fn{*d_table});
return ret;
}

std::unique_ptr<cudf::column> select_from_index(
cudf::strings_column_view const& then_and_else_scalar_column,
cudf::column_view const& select_index_column,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
cudf::size_type num_of_rows = select_index_column.size();
cudf::size_type num_of_scalar = then_and_else_scalar_column.size();

// create device views
auto d_scalars =
*(cudf::column_device_view::create(then_and_else_scalar_column.parent(), stream));
auto d_select_index = *(cudf::column_device_view::create(select_index_column, stream));

// Select <str_ptr, str_size> pairs from multiple scalars according to select index
using str_view = thrust::pair<char const*, cudf::size_type>;
rmm::device_uvector<str_view> indices(num_of_rows, stream);
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<cudf::size_type>(0),
thrust::make_counting_iterator<cudf::size_type>(num_of_rows),
indices.begin(),
[d_scalars, num_of_scalar, d_select_index] __device__(cudf::size_type row_idx) {
// select scalar according to index
cudf::size_type scalar_idx = d_select_index.element<cudf::size_type>(row_idx);

// return <str_ptr, str_size> pair
if (scalar_idx < num_of_scalar && !d_scalars.is_null(scalar_idx)) {
auto const d_str = d_scalars.element<cudf::string_view>(scalar_idx);
return str_view{d_str.data(), d_str.size_bytes()};
} else {
// index is out of bound, use NULL, for more details refer to comments in
// `select_first_true_fn`
return str_view{nullptr, 0};
}
});

// create final string column from string index pairs
return cudf::strings::detail::make_strings_column(indices.begin(), indices.end(), stream, mr);
}

} // namespace detail

std::unique_ptr<cudf::column> select_first_true_index(cudf::table_view const& when_bool_columns,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
return detail::select_first_true_index(when_bool_columns, stream, mr);
}

std::unique_ptr<cudf::column> select_from_index(
cudf::strings_column_view const& then_and_else_scalars_column,
cudf::column_view const& select_index_column,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
return detail::select_from_index(then_and_else_scalars_column, select_index_column, stream, mr);
}

} // namespace spark_rapids_jni
76 changes: 76 additions & 0 deletions src/main/cpp/src/case_when.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/column/column.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>

#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <map>
#include <memory>
#include <string>
#include <variant>
#include <vector>

namespace spark_rapids_jni {

/**
*
* Select the column index for the first true in bool columns.
* For the row does not contain true, use end index(number of columns).
*
* e.g.:
* column 0 in table: true, false, false, false
* column 1 in table: false, true, false, false
* column 2 in table: false, false, true, false
*
* 1st row is: true, flase, false; first true index is 0
* 2nd row is: false, true, false; first true index is 1
* 3rd row is: false, flase, false; first true index is 2
* 4th row is: false, false, false; do not find true, set index to the end index 3
*
* output column: 0, 1, 2, 3
* In the `case when` context, here 3 index means using NULL value.
*
*/
std::unique_ptr<cudf::column> select_first_true_index(
cudf::table_view const& when_bool_columns,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
*
* Select strings in scalar column according to index column.
* If index is out of bound, use NULL value
* e.g.:
* scalar column: s0, s1, s2
* index column: 0, 1, 2, 2, 1, 0, 3
* output column: s0, s1, s2, s2, s1, s0, NULL
*
*/
std::unique_ptr<cudf::column> select_from_index(
cudf::strings_column_view const& then_and_else_scalar_column,
cudf::column_view const& select_index_column,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
Copy link
Collaborator

@ttnghia ttnghia Jun 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wait. I just realize that this is just a gather. So we don't need this function at all. Just call cudf::gather (through Java_ai_rapids_cudf_Table_gather), which already supports all data types.


} // namespace spark_rapids_jni
Loading
Loading