Skip to content

Commit

Permalink
Fix transforms both on legacy and non-legacy TimestampType column
Browse files Browse the repository at this point in the history
  • Loading branch information
hantangwangd committed Feb 21, 2024
1 parent 6998a73 commit 4665527
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;

import java.math.BigDecimal;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.Decimals.encodeScaledValue;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.RealType.REAL;
Expand Down Expand Up @@ -132,7 +134,12 @@ else if (type.getJavaType() == double.class) {
type.writeDouble(blockBuilder, ((Number) value).doubleValue());
}
else if (type.getJavaType() == long.class) {
type.writeLong(blockBuilder, ((Number) value).longValue());
if (value instanceof BigDecimal) {
type.writeLong(blockBuilder, ((BigDecimal) value).unscaledValue().longValue());
}
else {
type.writeLong(blockBuilder, ((Number) value).longValue());
}
}
else if (type.getJavaType() == Slice.class) {
Slice slice;
Expand All @@ -142,6 +149,9 @@ else if (type.getJavaType() == Slice.class) {
else if (value instanceof String) {
slice = Slices.utf8Slice((String) value);
}
else if (value instanceof BigDecimal) {
slice = encodeScaledValue((BigDecimal) value);
}
else {
slice = (Slice) value;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.iceberg.StructLike;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types.DecimalType;

import java.io.IOException;
import java.io.StringWriter;
Expand Down Expand Up @@ -168,7 +169,7 @@ public static Object getValue(JsonNode partitionValue, Type type)
throw new UncheckedIOException("Failed during JSON conversion of " + partitionValue, e);
}
case DECIMAL:
return partitionValue.decimalValue();
return partitionValue.decimalValue().setScale(((DecimalType) type).scale());
}
throw new UnsupportedOperationException("Type not supported as partition column: " + type);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,147 @@ private void testCreatePartitionedTableAs(Session session, FileFormat fileFormat
dropTable(session, "test_create_partitioned_table_as_" + fileFormat.toString().toLowerCase(ENGLISH));
}

@Test
public void testPartitionOnDecimalColumn()
{
testWithAllFileFormats(this::testPartitionedByShortDecimalType);
testWithAllFileFormats(this::testPartitionedByLongDecimalType);
testWithAllFileFormats(this::testTruncateShortDecimalTransform);
testWithAllFileFormats(this::testTruncateLongDecimalTransform);
}

public void testPartitionedByShortDecimalType(Session session, FileFormat format)
{
// create iceberg table partitioned by column of ShortDecimalType, and insert some data
assertUpdate(session, "drop table if exists test_partition_columns_short_decimal");
assertUpdate(session, format("create table test_partition_columns_short_decimal(a bigint, b decimal(9, 2))" +
" with (format = '%s', partitioning = ARRAY['b'])", format.name()));
assertUpdate(session, "insert into test_partition_columns_short_decimal values(1, 12.31), (2, 133.28)", 2);
assertQuery(session, "select * from test_partition_columns_short_decimal", "values(1, 12.31), (2, 133.28)");

// validate column of ShortDecimalType exists in query filter
assertQuery(session, "select * from test_partition_columns_short_decimal where b = 133.28", "values(2, 133.28)");
assertQuery(session, "select * from test_partition_columns_short_decimal where b = 12.31", "values(1, 12.31)");

// validate column of ShortDecimalType in system table "partitions"
assertQuery(session, "select b, row_count from \"test_partition_columns_short_decimal$partitions\"", "values(12.31, 1), (133.28, 1)");

// validate column of TimestampType exists in delete filter
assertUpdate(session, "delete from test_partition_columns_short_decimal WHERE b = 12.31", 1);
assertQuery(session, "select * from test_partition_columns_short_decimal", "values(2, 133.28)");
assertQuery(session, "select * from test_partition_columns_short_decimal where b = 133.28", "values(2, 133.28)");

assertQuery(session, "select b, row_count from \"test_partition_columns_short_decimal$partitions\"", "values(133.28, 1)");

assertUpdate(session, "drop table test_partition_columns_short_decimal");
}

public void testPartitionedByLongDecimalType(Session session, FileFormat format)
{
// create iceberg table partitioned by column of ShortDecimalType, and insert some data
assertUpdate(session, "drop table if exists test_partition_columns_long_decimal");
assertUpdate(session, format("create table test_partition_columns_long_decimal(a bigint, b decimal(20, 2))" +
" with (format = '%s', partitioning = ARRAY['b'])", format.name()));
assertUpdate(session, "insert into test_partition_columns_long_decimal values(1, 11111111111111112.31), (2, 133.28)", 2);
assertQuery(session, "select * from test_partition_columns_long_decimal", "values(1, 11111111111111112.31), (2, 133.28)");

// validate column of ShortDecimalType exists in query filter
assertQuery(session, "select * from test_partition_columns_long_decimal where b = 133.28", "values(2, 133.28)");
assertQuery(session, "select * from test_partition_columns_long_decimal where b = 11111111111111112.31", "values(1, 11111111111111112.31)");

// validate column of ShortDecimalType in system table "partitions"
assertQuery(session, "select b, row_count from \"test_partition_columns_long_decimal$partitions\"",
"values(11111111111111112.31, 1), (133.28, 1)");

// validate column of TimestampType exists in delete filter
assertUpdate(session, "delete from test_partition_columns_long_decimal WHERE b = 11111111111111112.31", 1);
assertQuery(session, "select * from test_partition_columns_long_decimal", "values(2, 133.28)");
assertQuery(session, "select * from test_partition_columns_long_decimal where b = 133.28", "values(2, 133.28)");

assertQuery(session, "select b, row_count from \"test_partition_columns_long_decimal$partitions\"",
"values(133.28, 1)");

assertUpdate(session, "drop table test_partition_columns_long_decimal");
}

public void testTruncateShortDecimalTransform(Session session, FileFormat format)
{
assertUpdate(session, format("CREATE TABLE test_truncate_decimal_transform (d DECIMAL(9, 2), b BIGINT)" +
" WITH (format = '%s', partitioning = ARRAY['truncate(d, 10)'])", format.name()));
String select = "SELECT d_trunc, row_count, d.min, d.max FROM \"test_truncate_decimal_transform$partitions\"";

assertUpdate(session, "INSERT INTO test_truncate_decimal_transform VALUES" +
"(NULL, 101)," +
"(12.34, 1)," +
"(12.30, 2)," +
"(12.29, 3)," +
"(0.05, 4)," +
"(-0.05, 5)", 6);

assertQuery(session, "SELECT d_trunc FROM \"test_truncate_decimal_transform$partitions\"", "VALUES NULL, 12.30, 12.20, 0.00, -0.10");

assertQuery(session, "SELECT b FROM test_truncate_decimal_transform WHERE d IN (12.34, 12.30)", "VALUES 1, 2");
assertQuery(session, select + " WHERE d_trunc = 12.30",
"VALUES (12.30, 2, 12.30, 12.34)");

assertQuery(session, "SELECT b FROM test_truncate_decimal_transform WHERE d = 12.29", "VALUES 3");
assertQuery(session, select + " WHERE d_trunc = 12.20",
"VALUES (12.20, 1, 12.29, 12.29)");

assertQuery(session, "SELECT b FROM test_truncate_decimal_transform WHERE d = 0.05", "VALUES 4");
assertQuery(session, select + " WHERE d_trunc = 0.00",
"VALUES (0.00, 1, 0.05, 0.05)");

assertQuery(session, "SELECT b FROM test_truncate_decimal_transform WHERE d = -0.05", "VALUES 5");
assertQuery(session, select + " WHERE d_trunc = -0.10",
"VALUES (-0.10, 1, -0.05, -0.05)");

// Exercise IcebergMetadata.applyFilter with non-empty Constraint.predicate, via non-pushdownable predicates
assertQuery(session, "SELECT * FROM test_truncate_decimal_transform WHERE d * 100 % 10 = 9 AND b % 7 = 3",
"VALUES (12.29, 3)");

assertUpdate(session, "DROP TABLE test_truncate_decimal_transform");
}

public void testTruncateLongDecimalTransform(Session session, FileFormat format)
{
assertUpdate(session, format("CREATE TABLE test_truncate_long_decimal_transform (d DECIMAL(20, 2), b BIGINT)" +
" WITH (format = '%s', partitioning = ARRAY['truncate(d, 10)'])", format.name()));
String select = "SELECT d_trunc, row_count, d.min, d.max FROM \"test_truncate_long_decimal_transform$partitions\"";

assertUpdate(session, "INSERT INTO test_truncate_long_decimal_transform VALUES" +
"(NULL, 101)," +
"(12.34, 1)," +
"(12.30, 2)," +
"(11111111111111112.29, 3)," +
"(0.05, 4)," +
"(-0.05, 5)", 6);

assertQuery(session, "SELECT d_trunc FROM \"test_truncate_long_decimal_transform$partitions\"", "VALUES NULL, 12.30, 11111111111111112.20, 0.00, -0.10");

assertQuery(session, "SELECT b FROM test_truncate_long_decimal_transform WHERE d IN (12.34, 12.30)", "VALUES 1, 2");
assertQuery(session, select + " WHERE d_trunc = 12.30",
"VALUES (12.30, 2, 12.30, 12.34)");

assertQuery(session, "SELECT b FROM test_truncate_long_decimal_transform WHERE d = 11111111111111112.29", "VALUES 3");
assertQuery(session, select + " WHERE d_trunc = 11111111111111112.20",
"VALUES (11111111111111112.20, 1, 11111111111111112.29, 11111111111111112.29)");

assertQuery(session, "SELECT b FROM test_truncate_long_decimal_transform WHERE d = 0.05", "VALUES 4");
assertQuery(session, select + " WHERE d_trunc = 0.00",
"VALUES (0.00, 1, 0.05, 0.05)");

assertQuery(session, "SELECT b FROM test_truncate_long_decimal_transform WHERE d = -0.05", "VALUES 5");
assertQuery(session, select + " WHERE d_trunc = -0.10",
"VALUES (-0.10, 1, -0.05, -0.05)");

// Exercise IcebergMetadata.applyFilter with non-empty Constraint.predicate, via non-pushdownable predicates
assertQuery(session, "SELECT * FROM test_truncate_long_decimal_transform WHERE d * 100 % 10 = 9 AND b % 7 = 3",
"VALUES (11111111111111112.29, 3)");

assertUpdate(session, "DROP TABLE test_truncate_long_decimal_transform");
}

@Test
public void testColumnComments()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
package com.facebook.presto.spi;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.DecimalType;
import com.facebook.presto.common.type.Type;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand All @@ -28,6 +30,7 @@
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DateType.DATE;
import static com.facebook.presto.common.type.Decimals.encodeScaledValue;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
Expand Down Expand Up @@ -118,7 +121,15 @@ public long getLong(int field)
{
checkState(record != null, "no current record");
requireNonNull(record.get(field), "value is null");
return ((Number) record.get(field)).longValue();
Object value = record.get(field);
if (value instanceof BigDecimal) {
checkState(((DecimalType) this.getType(field)).isShort(),
"Expected ShortDecimalType");
return ((BigDecimal) value).unscaledValue().longValue();
}
else {
return ((Number) record.get(field)).longValue();
}
}

@Override
Expand All @@ -144,6 +155,9 @@ public Slice getSlice(int field)
if (value instanceof Slice) {
return (Slice) value;
}
if (value instanceof BigDecimal) {
return encodeScaledValue((BigDecimal) value);
}
throw new IllegalArgumentException("Field " + field + " is not a String, but is a " + value.getClass().getName());
}

Expand Down

0 comments on commit 4665527

Please sign in to comment.