Skip to content

Commit

Permalink
[Enhancement] Paimon partition pruning support cast predicate
Browse files Browse the repository at this point in the history
Signed-off-by: Jiao Mingye <mxdzs0612@gmail.com>
  • Loading branch information
mxdzs0612 committed Feb 11, 2025
1 parent 53ab97e commit 0e8a814
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

package com.starrocks.connector.paimon;

import com.starrocks.analysis.BoolLiteral;
import com.starrocks.catalog.PrimitiveType;
import com.starrocks.connector.exception.StarRocksConnectorException;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.CastOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
Expand All @@ -32,8 +34,22 @@
import org.apache.paimon.data.Decimal;
import org.apache.paimon.predicate.Predicate;
import org.apache.paimon.predicate.PredicateBuilder;
import org.apache.paimon.types.BigIntType;
import org.apache.paimon.types.BinaryType;
import org.apache.paimon.types.BooleanType;
import org.apache.paimon.types.CharType;
import org.apache.paimon.types.DataField;
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.DateType;
import org.apache.paimon.types.DecimalType;
import org.apache.paimon.types.DoubleType;
import org.apache.paimon.types.FloatType;
import org.apache.paimon.types.IntType;
import org.apache.paimon.types.RowType;
import org.apache.paimon.types.SmallIntType;
import org.apache.paimon.types.TimestampType;
import org.apache.paimon.types.TinyIntType;
import org.apache.paimon.types.VarCharType;

import java.math.BigDecimal;
import java.sql.Timestamp;
Expand All @@ -44,6 +60,7 @@
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import static org.apache.paimon.data.Timestamp.fromSQLTimestamp;
Expand All @@ -52,9 +69,11 @@ public class PaimonPredicateConverter extends ScalarOperatorVisitor<Predicate, V
private static final Logger LOG = LogManager.getLogger(PaimonPredicateConverter.class);
private final PredicateBuilder builder;
private final List<String> fieldNames;
private final List<DataType> fieldTypes;

/**
 * Creates a converter bound to the given Paimon row type.
 *
 * <p>The field names and field types are captured up front so that the visit methods can look up
 * a column's position by name and read the Paimon type stored at that position.
 *
 * @param rowType row type of the Paimon table the predicates are built for
 */
public PaimonPredicateConverter(RowType rowType) {
    List<DataField> fields = rowType.getFields();
    this.fieldNames = fields.stream().map(DataField::name).collect(Collectors.toList());
    this.fieldTypes = rowType.getFieldTypes();
    this.builder = new PredicateBuilder(rowType);
}

Expand Down Expand Up @@ -114,10 +133,13 @@ public Predicate visitBinaryPredicate(BinaryPredicateOperator operator, Void con
return null;
}
int idx = fieldNames.indexOf(columnName);
Object literal = getLiteral(operator.getChild(1));
Object literal = getLiteral(operator.getChild(1), fieldTypes.get(idx));
if (literal == null) {
return null;
}
if (fieldTypes.get(idx) instanceof BooleanType) {
literal = convertBoolLiteralValue(literal);
}
switch (operator.getBinaryType()) {
case LT:
return builder.lessThan(idx, literal);
Expand Down Expand Up @@ -146,10 +168,13 @@ public Predicate visitInPredicate(InPredicateOperator operator, Void context) {
List<ScalarOperator> valuesOperatorList = operator.getListChildren();
List<Object> literalValues = new ArrayList<>(valuesOperatorList.size());
for (ScalarOperator valueOperator : valuesOperatorList) {
Object value = getLiteral(valueOperator);
Object value = getLiteral(valueOperator, fieldTypes.get(idx));
if (value == null) {
return null;
}
if (fieldTypes.get(idx) instanceof BooleanType) {
value = convertBoolLiteralValue(value);
}
literalValues.add(value);
}

Expand All @@ -170,7 +195,11 @@ public Predicate visitLikePredicateOperator(LikePredicateOperator operator, Void
int idx = fieldNames.indexOf(columnName);
if (operator.getLikeType() == LikePredicateOperator.LikeType.LIKE) {
if (operator.getChild(1).getType().isStringType()) {
String literal = ((BinaryString) getLiteral(operator.getChild(1))).toString();
Object objectLiteral = getLiteral(operator.getChild(1), fieldTypes.get(idx));
if (objectLiteral == null) {
return null;
}
String literal = ((BinaryString) objectLiteral).toString();
if (literal.length() > 1 && literal.indexOf("%") == literal.length() - 1 && literal.charAt(0) != '%') {
return builder.startsWith(idx, BinaryString.fromString(literal.substring(0, literal.length() - 1)));
}
Expand All @@ -179,48 +208,128 @@ public Predicate visitLikePredicateOperator(LikePredicateOperator operator, Void
return null;
}

private Object getLiteral(ScalarOperator operator) {
if (!(operator instanceof ConstantOperator)) {
private Object getLiteral(ScalarOperator operator, DataType dataType) {
if (operator == null) {
return null;
}

ConstantOperator constValue = (ConstantOperator) operator;
switch (constValue.getType().getPrimitiveType()) {
case BOOLEAN:
return constValue.getBoolean();
case TINYINT:
return constValue.getTinyInt();
case SMALLINT:
return constValue.getSmallint();
case INT:
return constValue.getInt();
case BIGINT:
return constValue.getBigint();
case FLOAT:
return constValue.getFloat();
case DOUBLE:
return constValue.getDouble();
case DECIMALV2:
case DECIMAL32:
case DECIMAL64:
case DECIMAL128:
BigDecimal bigDecimal = constValue.getDecimal();
PrimitiveType type = constValue.getType().getPrimitiveType();
return Decimal.fromBigDecimal(bigDecimal, PrimitiveType.getMaxPrecisionOfDecimal(type),
PrimitiveType.getDefaultScaleOfDecimal(type));
case HLL:
case VARCHAR:
case CHAR:
return BinaryString.fromString(constValue.getVarchar());
case DATE:
LocalDate localDate = constValue.getDate().toLocalDate();
LocalDate epochDay = Instant.ofEpochSecond(0).atOffset(ZoneOffset.UTC).toLocalDate();
return (int) ChronoUnit.DAYS.between(epochDay, localDate);
case DATETIME:
LocalDateTime localDateTime = constValue.getDatetime();
return fromSQLTimestamp(Timestamp.valueOf((localDateTime)));
default:
return operator.accept(new PaimonPredicateConverter.ExtractLiteralValue(), dataType);
}

private static class ExtractLiteralValue extends ScalarOperatorVisitor<Object, DataType> {
@Override
public Object visit(ScalarOperator scalarOperator, DataType dataType) {
return null;
}

@Override
public Object visitConstant(ConstantOperator operator, DataType dataType) {
if (needCast(operator.getType().getPrimitiveType(), dataType)) {
operator = tryCastToResultType(operator, dataType);
}
if (operator == null) {
return null;
}
switch (operator.getType().getPrimitiveType()) {
case BOOLEAN:
return operator.getBoolean();
case TINYINT:
return operator.getTinyInt();
case SMALLINT:
return operator.getSmallint();
case INT:
return operator.getInt();
case BIGINT:
return operator.getBigint();
case FLOAT:
return operator.getFloat();
case DOUBLE:
return operator.getDouble();
case DECIMALV2:
case DECIMAL32:
case DECIMAL64:
case DECIMAL128:
BigDecimal bigDecimal = operator.getDecimal();
PrimitiveType type = operator.getType().getPrimitiveType();
return Decimal.fromBigDecimal(bigDecimal, PrimitiveType.getMaxPrecisionOfDecimal(type),
PrimitiveType.getDefaultScaleOfDecimal(type));
case HLL:
case VARCHAR:
case CHAR:
return BinaryString.fromString(operator.getVarchar());
case DATE:
LocalDate localDate = operator.getDate().toLocalDate();
LocalDate epochDay = Instant.ofEpochSecond(0).atOffset(ZoneOffset.UTC).toLocalDate();
return (int) ChronoUnit.DAYS.between(epochDay, localDate);
case DATETIME:
LocalDateTime localDateTime = operator.getDatetime();
return fromSQLTimestamp(Timestamp.valueOf((localDateTime)));
default:
return null;

}
}

@Override
public Object visitCastOperator(CastOperator operator, DataType dataType) {
return operator.getChild(0).accept(this, dataType);
}

private boolean needCast(PrimitiveType sourceType, DataType dataType) {
return switch (sourceType) {
case BOOLEAN -> !(dataType instanceof BooleanType);
case TINYINT -> !(dataType instanceof TinyIntType);
case SMALLINT -> !(dataType instanceof SmallIntType);
case INT -> !(dataType instanceof IntType);
case BIGINT -> !(dataType instanceof BigIntType);
case FLOAT -> !(dataType instanceof FloatType);
case DOUBLE -> !(dataType instanceof DoubleType);
case DECIMALV2, DECIMAL32, DECIMAL64, DECIMAL128 -> !(dataType instanceof DecimalType);
case HLL, VARCHAR -> !(dataType instanceof VarCharType);
case CHAR -> !(dataType instanceof CharType);
case DATE -> !(dataType instanceof DateType);
case DATETIME -> !(dataType instanceof TimestampType);
default -> true;
};
}

private ConstantOperator tryCastToResultType(ConstantOperator operator, DataType dataType) {
Optional<ConstantOperator> res = Optional.empty();
if (dataType instanceof BooleanType) {
res = operator.castTo(com.starrocks.catalog.Type.BOOLEAN);
} else if (dataType instanceof DateType) {
res = operator.castTo(com.starrocks.catalog.Type.DATE);
} else if (dataType instanceof TimestampType) {
res = operator.castTo(com.starrocks.catalog.Type.DATETIME);
} else if (dataType instanceof VarCharType) {
res = operator.castTo(com.starrocks.catalog.Type.STRING);
} else if (dataType instanceof CharType) {
res = operator.castTo(com.starrocks.catalog.Type.CHAR);
} else if (dataType instanceof BinaryType) {
res = operator.castTo(com.starrocks.catalog.Type.VARBINARY);
} else if (dataType instanceof IntType) {
res = operator.castTo(com.starrocks.catalog.Type.INT);
} else if (dataType instanceof BigIntType) {
res = operator.castTo(com.starrocks.catalog.Type.BIGINT);
} else if (dataType instanceof TinyIntType) {
res = operator.castTo(com.starrocks.catalog.Type.TINYINT);
} else if (dataType instanceof SmallIntType) {
res = operator.castTo(com.starrocks.catalog.Type.SMALLINT);
} else if (dataType instanceof FloatType) {
res = operator.castTo(com.starrocks.catalog.Type.FLOAT);
} else if (dataType instanceof DoubleType) {
res = operator.castTo(com.starrocks.catalog.Type.DOUBLE);
}
return res.orElse(operator);
}
}

/**
 * Normalizes a literal destined for a boolean column by parsing its string form with
 * {@link BoolLiteral}, so that both 0/1 and true/false spellings are supported.
 *
 * @param literalValue literal extracted from the predicate
 * @return the boolean value parsed from {@code String.valueOf(literalValue)}
 * @throws StarRocksConnectorException if the parse fails for any reason
 */
private static Object convertBoolLiteralValue(Object literalValue) {
    try {
        String text = String.valueOf(literalValue);
        BoolLiteral parsed = new BoolLiteral(text);
        return parsed.getValue();
    } catch (Exception e) {
        throw new StarRocksConnectorException("Failed to convert %s to boolean type", literalValue);
    }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.starrocks.analysis.BinaryType;
import com.starrocks.catalog.Type;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.CastOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.CompoundPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
Expand All @@ -26,6 +27,7 @@
import com.starrocks.sql.optimizer.operator.scalar.LikePredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import org.apache.paimon.data.BinaryString;
import org.apache.paimon.data.Timestamp;
import org.apache.paimon.predicate.And;
import org.apache.paimon.predicate.CompoundPredicate;
import org.apache.paimon.predicate.Equal;
Expand All @@ -40,14 +42,18 @@
import org.apache.paimon.predicate.Or;
import org.apache.paimon.predicate.Predicate;
import org.apache.paimon.predicate.StartsWith;
import org.apache.paimon.types.BooleanType;
import org.apache.paimon.types.DataField;
import org.apache.paimon.types.DateType;
import org.apache.paimon.types.FloatType;
import org.apache.paimon.types.IntType;
import org.apache.paimon.types.RowType;
import org.apache.paimon.types.TimestampType;
import org.apache.paimon.types.VarCharType;
import org.junit.Assert;
import org.junit.Test;

import java.time.LocalDate;
import java.util.Arrays;
import java.util.List;

Expand All @@ -57,9 +63,16 @@ public class PaimonPredicateConverterTest {
Arrays.asList(
new DataField(0, "f0", new IntType()),
new DataField(1, "f1", new VarCharType()),
new DataField(2, "f2", new FloatType()));
new DataField(2, "f2", new FloatType()),
new DataField(3, "f3", new DateType()),
new DataField(4, "f4", new BooleanType()),
new DataField(5, "f5", new TimestampType()));
private static final ColumnRefOperator F0 = new ColumnRefOperator(0, Type.INT, "f0", true, false);
private static final ColumnRefOperator F1 = new ColumnRefOperator(0, Type.VARCHAR, "f1", true, false);
private static final ColumnRefOperator F1 = new ColumnRefOperator(1, Type.VARCHAR, "f1", true, false);
private static final ColumnRefOperator F2 = new ColumnRefOperator(2, Type.FLOAT, "f2", true, false);
private static final ColumnRefOperator F3 = new ColumnRefOperator(3, Type.DATE, "f3", true, false);
private static final ColumnRefOperator F4 = new ColumnRefOperator(4, Type.BOOLEAN, "f4", true, false);
private static final ColumnRefOperator F5 = new ColumnRefOperator(5, Type.DATETIME, "f5", true, false);
private static final PaimonPredicateConverter CONVERTER = new PaimonPredicateConverter(new RowType(DATA_FIELDS));

@Test
Expand Down Expand Up @@ -255,4 +268,73 @@ public void testBinaryString() {
LeafPredicate leafPredicate = (LeafPredicate) result;
Assert.assertEquals(BinaryString.fromString("ttt"), leafPredicate.literals().get(0));
}

/**
 * Checks the literal produced for {@code CAST(column) = constant} predicates, covering constants
 * whose type differs from the Paimon type of the referenced column.
 */
@Test
public void testPaimonCastPredicate() {
    // DOUBLE constant compared with INT column f0.
    Assert.assertEquals(11,
            firstLiteralOfCastEq(Type.INT, F0, ConstantOperator.createDouble(11.11)));

    // VARCHAR constant compared with VARCHAR column f1 (column cast to DATE).
    Assert.assertEquals(BinaryString.fromString("2025-01-01"),
            firstLiteralOfCastEq(Type.DATE, F1, ConstantOperator.createVarchar("2025-01-01")));

    // FLOAT constant compared with FLOAT column f2 (column cast to DOUBLE).
    Assert.assertEquals(11.11,
            firstLiteralOfCastEq(Type.DOUBLE, F2, ConstantOperator.createFloat(11.11)));

    // DATE constant compared with DATE column f3 (column cast to STRING).
    ConstantOperator dateConstant = ConstantOperator.createDate(
            LocalDate.parse("2025-01-01").atTime(0, 0, 0, 0));
    Assert.assertEquals(20089, firstLiteralOfCastEq(Type.STRING, F3, dateConstant));

    // BOOLEAN constant compared with VARCHAR column f1.
    Assert.assertEquals(BinaryString.fromString("1"),
            firstLiteralOfCastEq(Type.INT, F1, ConstantOperator.createBoolean(true)));

    // BOOLEAN constant compared with INT column f0.
    Assert.assertEquals(0,
            firstLiteralOfCastEq(Type.INT, F0, ConstantOperator.createBoolean(false)));

    // DATETIME constant compared with VARCHAR column f1.
    ConstantOperator datetimeConstant = ConstantOperator.createDatetime(
            LocalDate.parse("2025-01-01").atTime(0, 0, 0, 0));
    Assert.assertEquals(BinaryString.fromString("2025-01-01 00:00:00"),
            firstLiteralOfCastEq(Type.VARCHAR, F1, datetimeConstant));

    // VARCHAR constant compared with BOOLEAN column f4.
    Assert.assertEquals(false,
            firstLiteralOfCastEq(Type.BOOLEAN, F4, ConstantOperator.createVarchar("false")));

    // VARCHAR constant compared with TIMESTAMP column f5.
    Object timestampLiteral = firstLiteralOfCastEq(Type.DATETIME, F5,
            ConstantOperator.createVarchar("2025-01-01 00:00:00"));
    Assert.assertEquals(1735689600000L, ((Timestamp) timestampLiteral).getMillisecond());
}

/**
 * Converts {@code CAST(column AS castType) = constant}, asserts that the result is a leaf
 * predicate, and returns that predicate's first literal.
 */
private static Object firstLiteralOfCastEq(Type castType, ColumnRefOperator column, ConstantOperator constant) {
    CastOperator castColumn = new CastOperator(castType, column);
    Predicate converted = CONVERTER.convert(new BinaryPredicateOperator(BinaryType.EQ, castColumn, constant));
    Assert.assertTrue(converted instanceof LeafPredicate);
    return ((LeafPredicate) converted).literals().get(0);
}
}

0 comments on commit 0e8a814

Please sign in to comment.