From 4665527ac81b7689c934ed604ae54eab970ef710 Mon Sep 17 00:00:00 2001 From: wangd Date: Wed, 21 Feb 2024 18:15:54 +0800 Subject: [PATCH] Fix transforms both on legacy and non-legacy TimestampType column --- .../presto/common/type/TypeUtils.java | 12 +- .../presto/iceberg/PartitionData.java | 3 +- .../IcebergDistributedSmokeTestBase.java | 141 ++++++++++++++++++ .../presto/spi/InMemoryRecordSet.java | 16 +- 4 files changed, 169 insertions(+), 3 deletions(-) diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/TypeUtils.java b/presto-common/src/main/java/com/facebook/presto/common/type/TypeUtils.java index accd51f81b3c1..1de251ac68842 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/type/TypeUtils.java +++ b/presto-common/src/main/java/com/facebook/presto/common/type/TypeUtils.java @@ -19,6 +19,7 @@ import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import java.math.BigDecimal; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -26,6 +27,7 @@ import java.util.stream.Collectors; import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.Decimals.encodeScaledValue; import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.RealType.REAL; @@ -132,7 +134,12 @@ else if (type.getJavaType() == double.class) { type.writeDouble(blockBuilder, ((Number) value).doubleValue()); } else if (type.getJavaType() == long.class) { - type.writeLong(blockBuilder, ((Number) value).longValue()); + if (value instanceof BigDecimal) { + type.writeLong(blockBuilder, ((BigDecimal) value).unscaledValue().longValue()); + } + else { + type.writeLong(blockBuilder, ((Number) value).longValue()); + } } else if (type.getJavaType() == Slice.class) { Slice slice; @@ -142,6 +149,9 @@ else if (type.getJavaType() == Slice.class) { else if (value instanceof 
String) { slice = Slices.utf8Slice((String) value); } + else if (value instanceof BigDecimal) { + slice = encodeScaledValue((BigDecimal) value); + } else { slice = (Slice) value; } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionData.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionData.java index 226157405f092..ffffcba9f65a4 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionData.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/PartitionData.java @@ -20,6 +20,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.iceberg.StructLike; import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types.DecimalType; import java.io.IOException; import java.io.StringWriter; @@ -168,7 +169,7 @@ public static Object getValue(JsonNode partitionValue, Type type) throw new UncheckedIOException("Failed during JSON conversion of " + partitionValue, e); } case DECIMAL: - return partitionValue.decimalValue(); + return partitionValue.decimalValue().setScale(((DecimalType) type).scale()); } throw new UnsupportedOperationException("Type not supported as partition column: " + type); } diff --git a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedSmokeTestBase.java b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedSmokeTestBase.java index efa8563feb519..31feef376d11e 100644 --- a/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedSmokeTestBase.java +++ b/presto-iceberg/src/test/java/com/facebook/presto/iceberg/IcebergDistributedSmokeTestBase.java @@ -397,6 +397,147 @@ private void testCreatePartitionedTableAs(Session session, FileFormat fileFormat dropTable(session, "test_create_partitioned_table_as_" + fileFormat.toString().toLowerCase(ENGLISH)); } + @Test + public void testPartitionOnDecimalColumn() + { + testWithAllFileFormats(this::testPartitionedByShortDecimalType); + 
testWithAllFileFormats(this::testPartitionedByLongDecimalType); + testWithAllFileFormats(this::testTruncateShortDecimalTransform); + testWithAllFileFormats(this::testTruncateLongDecimalTransform); + } + + public void testPartitionedByShortDecimalType(Session session, FileFormat format) + { + // create iceberg table partitioned by column of ShortDecimalType, and insert some data + assertUpdate(session, "drop table if exists test_partition_columns_short_decimal"); + assertUpdate(session, format("create table test_partition_columns_short_decimal(a bigint, b decimal(9, 2))" + + " with (format = '%s', partitioning = ARRAY['b'])", format.name())); + assertUpdate(session, "insert into test_partition_columns_short_decimal values(1, 12.31), (2, 133.28)", 2); + assertQuery(session, "select * from test_partition_columns_short_decimal", "values(1, 12.31), (2, 133.28)"); + + // validate column of ShortDecimalType exists in query filter + assertQuery(session, "select * from test_partition_columns_short_decimal where b = 133.28", "values(2, 133.28)"); + assertQuery(session, "select * from test_partition_columns_short_decimal where b = 12.31", "values(1, 12.31)"); + + // validate column of ShortDecimalType in system table "partitions" + assertQuery(session, "select b, row_count from \"test_partition_columns_short_decimal$partitions\"", "values(12.31, 1), (133.28, 1)"); + + // validate column of ShortDecimalType exists in delete filter + assertUpdate(session, "delete from test_partition_columns_short_decimal WHERE b = 12.31", 1); + assertQuery(session, "select * from test_partition_columns_short_decimal", "values(2, 133.28)"); + assertQuery(session, "select * from test_partition_columns_short_decimal where b = 133.28", "values(2, 133.28)"); + + assertQuery(session, "select b, row_count from \"test_partition_columns_short_decimal$partitions\"", "values(133.28, 1)"); + + assertUpdate(session, "drop table test_partition_columns_short_decimal"); + } + + public void
testPartitionedByLongDecimalType(Session session, FileFormat format) + { + // create iceberg table partitioned by column of LongDecimalType, and insert some data + assertUpdate(session, "drop table if exists test_partition_columns_long_decimal"); + assertUpdate(session, format("create table test_partition_columns_long_decimal(a bigint, b decimal(20, 2))" + + " with (format = '%s', partitioning = ARRAY['b'])", format.name())); + assertUpdate(session, "insert into test_partition_columns_long_decimal values(1, 11111111111111112.31), (2, 133.28)", 2); + assertQuery(session, "select * from test_partition_columns_long_decimal", "values(1, 11111111111111112.31), (2, 133.28)"); + + // validate column of LongDecimalType exists in query filter + assertQuery(session, "select * from test_partition_columns_long_decimal where b = 133.28", "values(2, 133.28)"); + assertQuery(session, "select * from test_partition_columns_long_decimal where b = 11111111111111112.31", "values(1, 11111111111111112.31)"); + + // validate column of LongDecimalType in system table "partitions" + assertQuery(session, "select b, row_count from \"test_partition_columns_long_decimal$partitions\"", + "values(11111111111111112.31, 1), (133.28, 1)"); + + // validate column of LongDecimalType exists in delete filter + assertUpdate(session, "delete from test_partition_columns_long_decimal WHERE b = 11111111111111112.31", 1); + assertQuery(session, "select * from test_partition_columns_long_decimal", "values(2, 133.28)"); + assertQuery(session, "select * from test_partition_columns_long_decimal where b = 133.28", "values(2, 133.28)"); + + assertQuery(session, "select b, row_count from \"test_partition_columns_long_decimal$partitions\"", + "values(133.28, 1)"); + + assertUpdate(session, "drop table test_partition_columns_long_decimal"); + } + + public void testTruncateShortDecimalTransform(Session session, FileFormat format) + { + assertUpdate(session, format("CREATE TABLE test_truncate_decimal_transform (d
DECIMAL(9, 2), b BIGINT)" + + " WITH (format = '%s', partitioning = ARRAY['truncate(d, 10)'])", format.name())); + String select = "SELECT d_trunc, row_count, d.min, d.max FROM \"test_truncate_decimal_transform$partitions\""; + + assertUpdate(session, "INSERT INTO test_truncate_decimal_transform VALUES" + + "(NULL, 101)," + + "(12.34, 1)," + + "(12.30, 2)," + + "(12.29, 3)," + + "(0.05, 4)," + + "(-0.05, 5)", 6); + + assertQuery(session, "SELECT d_trunc FROM \"test_truncate_decimal_transform$partitions\"", "VALUES NULL, 12.30, 12.20, 0.00, -0.10"); + + assertQuery(session, "SELECT b FROM test_truncate_decimal_transform WHERE d IN (12.34, 12.30)", "VALUES 1, 2"); + assertQuery(session, select + " WHERE d_trunc = 12.30", + "VALUES (12.30, 2, 12.30, 12.34)"); + + assertQuery(session, "SELECT b FROM test_truncate_decimal_transform WHERE d = 12.29", "VALUES 3"); + assertQuery(session, select + " WHERE d_trunc = 12.20", + "VALUES (12.20, 1, 12.29, 12.29)"); + + assertQuery(session, "SELECT b FROM test_truncate_decimal_transform WHERE d = 0.05", "VALUES 4"); + assertQuery(session, select + " WHERE d_trunc = 0.00", + "VALUES (0.00, 1, 0.05, 0.05)"); + + assertQuery(session, "SELECT b FROM test_truncate_decimal_transform WHERE d = -0.05", "VALUES 5"); + assertQuery(session, select + " WHERE d_trunc = -0.10", + "VALUES (-0.10, 1, -0.05, -0.05)"); + + // Exercise IcebergMetadata.applyFilter with non-empty Constraint.predicate, via non-pushdownable predicates + assertQuery(session, "SELECT * FROM test_truncate_decimal_transform WHERE d * 100 % 10 = 9 AND b % 7 = 3", + "VALUES (12.29, 3)"); + + assertUpdate(session, "DROP TABLE test_truncate_decimal_transform"); + } + + public void testTruncateLongDecimalTransform(Session session, FileFormat format) + { + assertUpdate(session, format("CREATE TABLE test_truncate_long_decimal_transform (d DECIMAL(20, 2), b BIGINT)" + + " WITH (format = '%s', partitioning = ARRAY['truncate(d, 10)'])", format.name())); + String select = "SELECT 
d_trunc, row_count, d.min, d.max FROM \"test_truncate_long_decimal_transform$partitions\""; + + assertUpdate(session, "INSERT INTO test_truncate_long_decimal_transform VALUES" + + "(NULL, 101)," + + "(12.34, 1)," + + "(12.30, 2)," + + "(11111111111111112.29, 3)," + + "(0.05, 4)," + + "(-0.05, 5)", 6); + + assertQuery(session, "SELECT d_trunc FROM \"test_truncate_long_decimal_transform$partitions\"", "VALUES NULL, 12.30, 11111111111111112.20, 0.00, -0.10"); + + assertQuery(session, "SELECT b FROM test_truncate_long_decimal_transform WHERE d IN (12.34, 12.30)", "VALUES 1, 2"); + assertQuery(session, select + " WHERE d_trunc = 12.30", + "VALUES (12.30, 2, 12.30, 12.34)"); + + assertQuery(session, "SELECT b FROM test_truncate_long_decimal_transform WHERE d = 11111111111111112.29", "VALUES 3"); + assertQuery(session, select + " WHERE d_trunc = 11111111111111112.20", + "VALUES (11111111111111112.20, 1, 11111111111111112.29, 11111111111111112.29)"); + + assertQuery(session, "SELECT b FROM test_truncate_long_decimal_transform WHERE d = 0.05", "VALUES 4"); + assertQuery(session, select + " WHERE d_trunc = 0.00", + "VALUES (0.00, 1, 0.05, 0.05)"); + + assertQuery(session, "SELECT b FROM test_truncate_long_decimal_transform WHERE d = -0.05", "VALUES 5"); + assertQuery(session, select + " WHERE d_trunc = -0.10", + "VALUES (-0.10, 1, -0.05, -0.05)"); + + // Exercise IcebergMetadata.applyFilter with non-empty Constraint.predicate, via non-pushdownable predicates + assertQuery(session, "SELECT * FROM test_truncate_long_decimal_transform WHERE d * 100 % 10 = 9 AND b % 7 = 3", + "VALUES (11111111111111112.29, 3)"); + + assertUpdate(session, "DROP TABLE test_truncate_long_decimal_transform"); + } + @Test public void testColumnComments() { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/InMemoryRecordSet.java b/presto-spi/src/main/java/com/facebook/presto/spi/InMemoryRecordSet.java index 07f1f28c9cf3f..69bb7a8429cb1 100644 --- 
a/presto-spi/src/main/java/com/facebook/presto/spi/InMemoryRecordSet.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/InMemoryRecordSet.java @@ -14,10 +14,12 @@ package com.facebook.presto.spi; import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.DecimalType; import com.facebook.presto.common.type.Type; import io.airlift.slice.Slice; import io.airlift.slice.Slices; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -28,6 +30,7 @@ import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.Decimals.encodeScaledValue; import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.IntegerType.INTEGER; import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; @@ -118,7 +121,15 @@ public long getLong(int field) { checkState(record != null, "no current record"); requireNonNull(record.get(field), "value is null"); - return ((Number) record.get(field)).longValue(); + Object value = record.get(field); + if (value instanceof BigDecimal) { + checkState(((DecimalType) this.getType(field)).isShort(), + "Expected ShortDecimalType"); + return ((BigDecimal) value).unscaledValue().longValue(); + } + else { + return ((Number) record.get(field)).longValue(); + } } @Override @@ -144,6 +155,9 @@ public Slice getSlice(int field) if (value instanceof Slice) { return (Slice) value; } + if (value instanceof BigDecimal) { + return encodeScaledValue((BigDecimal) value); + } throw new IllegalArgumentException("Field " + field + " is not a String, but is a " + value.getClass().getName()); }