Skip to content

Commit

Permalink
[GLUTEN-5580][CH] Fix float-to-int cast when the value exceeds the integer maximum (#5581)
Browse files Browse the repository at this point in the history
What changes were proposed in this pull request?
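Clamp float-to-integer casts in the ClickHouse backend: a value greater than the target type's maximum now yields that maximum, and a value less than the target type's minimum yields that minimum, instead of being passed unchecked to static_cast. To support this, SparkFunctionCastFloatToInt takes the target type's maximum and minimum as additional template parameters.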

(Fixes: #5580)

How was this patch tested?
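Tested by a unit test: the "GLUTEN-3149/GLUTEN-5580: Fix convert float to int" case in GlutenClickHouseTPCHSaltNullParquetSuite.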
  • Loading branch information
KevinyhZou authored May 7, 2024
1 parent 18af4bc commit b96ddb4
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2094,13 +2094,15 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}

test("GLUTEN-3149: Fix convert exception of Inf to int") {
val tbl_create_sql = "create table test_tbl_3149(a int, b int) using parquet";
val tbl_insert_sql = "insert into test_tbl_3149 values(1, 0)"
val select_sql = "select cast(a * 1.0f/b as int) as x from test_tbl_3149 where a = 1"
test("GLUTEN-3149/GLUTEN-5580: Fix convert float to int") {
val tbl_create_sql = "create table test_tbl_3149(a int, b bigint) using parquet";
val tbl_insert_sql = "insert into test_tbl_3149 values(1, 0), (2, 171396196666200)"
val select_sql_1 = "select cast(a * 1.0f/b as int) as x from test_tbl_3149 where a = 1"
val select_sql_2 = "select cast(b/100 as int) from test_tbl_3149 where a = 2"
spark.sql(tbl_create_sql)
spark.sql(tbl_insert_sql);
compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
compareResultsAgainstVanillaSpark(select_sql_1, true, { _ => })
compareResultsAgainstVanillaSpark(select_sql_2, true, { _ => })
spark.sql("drop table test_tbl_3149")
}

Expand Down
27 changes: 15 additions & 12 deletions cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
* limitations under the License.
*/

#include <limits.h>
#include <base/types.h>
#include <base/wide_integer.h>
#include <base/wide_integer_impl.h>
#include <Functions/SparkFunctionCastFloatToInt.h>

using namespace DB;
Expand All @@ -36,18 +39,18 @@ struct NameToInt64 { static constexpr auto name = "sparkCastFloatToInt64"; };
struct NameToInt128 { static constexpr auto name = "sparkCastFloatToInt128"; };
struct NameToInt256 { static constexpr auto name = "sparkCastFloatToInt256"; };

using SparkFunctionCastFloatToInt8 = local_engine::SparkFunctionCastFloatToInt<Int8, NameToInt8>;
using SparkFunctionCastFloatToInt16 = local_engine::SparkFunctionCastFloatToInt<Int16, NameToInt16>;
using SparkFunctionCastFloatToInt32 = local_engine::SparkFunctionCastFloatToInt<Int32, NameToInt32>;
using SparkFunctionCastFloatToInt64 = local_engine::SparkFunctionCastFloatToInt<Int64, NameToInt64>;
using SparkFunctionCastFloatToInt128 = local_engine::SparkFunctionCastFloatToInt<Int128, NameToInt128>;
using SparkFunctionCastFloatToInt256 = local_engine::SparkFunctionCastFloatToInt<Int256, NameToInt256>;
using SparkFunctionCastFloatToUInt8 = local_engine::SparkFunctionCastFloatToInt<UInt8, NameToUInt8>;
using SparkFunctionCastFloatToUInt16 = local_engine::SparkFunctionCastFloatToInt<UInt16, NameToUInt16>;
using SparkFunctionCastFloatToUInt32 = local_engine::SparkFunctionCastFloatToInt<UInt32, NameToUInt32>;
using SparkFunctionCastFloatToUInt64 = local_engine::SparkFunctionCastFloatToInt<UInt64, NameToUInt64>;
using SparkFunctionCastFloatToUInt128 = local_engine::SparkFunctionCastFloatToInt<UInt128, NameToUInt128>;
using SparkFunctionCastFloatToUInt256 = local_engine::SparkFunctionCastFloatToInt<UInt256, NameToUInt256>;
using SparkFunctionCastFloatToInt8 = local_engine::SparkFunctionCastFloatToInt<Int8, NameToInt8, INT8_MAX, INT8_MIN>;
using SparkFunctionCastFloatToInt16 = local_engine::SparkFunctionCastFloatToInt<Int16, NameToInt16, INT16_MAX, INT16_MIN>;
using SparkFunctionCastFloatToInt32 = local_engine::SparkFunctionCastFloatToInt<Int32, NameToInt32, INT32_MAX, INT32_MIN>;
using SparkFunctionCastFloatToInt64 = local_engine::SparkFunctionCastFloatToInt<Int64, NameToInt64, INT64_MAX, INT64_MIN>;
using SparkFunctionCastFloatToInt128 = local_engine::SparkFunctionCastFloatToInt<Int128, NameToInt128, std::numeric_limits<Int128>::max(), std::numeric_limits<Int128>::min()>;
using SparkFunctionCastFloatToInt256 = local_engine::SparkFunctionCastFloatToInt<Int256, NameToInt256, std::numeric_limits<Int256>::max(), std::numeric_limits<Int256>::min()>;
using SparkFunctionCastFloatToUInt8 = local_engine::SparkFunctionCastFloatToInt<UInt8, NameToUInt8, UINT8_MAX, 0>;
using SparkFunctionCastFloatToUInt16 = local_engine::SparkFunctionCastFloatToInt<UInt16, NameToUInt16, UINT16_MAX, 0>;
using SparkFunctionCastFloatToUInt32 = local_engine::SparkFunctionCastFloatToInt<UInt32, NameToUInt32, UINT32_MAX, 0>;
using SparkFunctionCastFloatToUInt64 = local_engine::SparkFunctionCastFloatToInt<UInt64, NameToUInt64, UINT64_MAX, 0>;
using SparkFunctionCastFloatToUInt128 = local_engine::SparkFunctionCastFloatToInt<UInt128, NameToUInt128, std::numeric_limits<UInt128>::max(), 0>;
using SparkFunctionCastFloatToUInt256 = local_engine::SparkFunctionCastFloatToInt<UInt256, NameToUInt256, std::numeric_limits<UInt256>::max(), 0>;

REGISTER_FUNCTION(SparkFunctionCastToInt)
{
Expand Down
12 changes: 6 additions & 6 deletions cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace ErrorCodes
namespace local_engine
{

template <typename T, typename Name>
template <typename T, typename Name, T int_max_value, T int_min_value>
class SparkFunctionCastFloatToInt : public DB::IFunction
{
public:
Expand Down Expand Up @@ -74,7 +74,7 @@ class SparkFunctionCastFloatToInt : public DB::IFunction
DB::ColumnPtr src_col = arguments[0].column;
size_t size = src_col->size();

auto res_col = DB::ColumnVector<T>::create(size);
auto res_col = DB::ColumnVector<T>::create(size, 0);
auto null_map_col = DB::ColumnUInt8::create(size, 0);

switch(removeNullable(arguments[0].type)->getTypeId())
Expand All @@ -101,15 +101,15 @@ class SparkFunctionCastFloatToInt : public DB::IFunction
{
F element = src_vec->getElement(i);
if (isNaN(element) || !isFinite(element))
{
data[i] = 0;
null_map_data[i] = 1;
}
else if (element > int_max_value)
data[i] = int_max_value;
else if (element < int_min_value)
data[i] = int_min_value;
else
data[i] = static_cast<T>(element);
}
}

};

}

0 comments on commit b96ddb4

Please sign in to comment.