Skip to content

Commit

Permalink
[Enhancement] Paimon partition pruning support cast predicate
Browse files Browse the repository at this point in the history
Signed-off-by: Jiao Mingye <mxdzs0612@gmail.com>
  • Loading branch information
mxdzs0612 committed Feb 11, 2025
1 parent 53ab97e commit 8ec0ee3
Show file tree
Hide file tree
Showing 2 changed files with 248 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

package com.starrocks.connector.paimon;

import com.starrocks.analysis.BoolLiteral;
import com.starrocks.catalog.PrimitiveType;
import com.starrocks.connector.exception.StarRocksConnectorException;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.CastOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
Expand All @@ -32,8 +34,22 @@
import org.apache.paimon.data.Decimal;
import org.apache.paimon.predicate.Predicate;
import org.apache.paimon.predicate.PredicateBuilder;
import org.apache.paimon.types.BigIntType;
import org.apache.paimon.types.BinaryType;
import org.apache.paimon.types.BooleanType;
import org.apache.paimon.types.CharType;
import org.apache.paimon.types.DataField;
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.DateType;
import org.apache.paimon.types.DecimalType;
import org.apache.paimon.types.DoubleType;
import org.apache.paimon.types.FloatType;
import org.apache.paimon.types.IntType;
import org.apache.paimon.types.RowType;
import org.apache.paimon.types.SmallIntType;
import org.apache.paimon.types.TimestampType;
import org.apache.paimon.types.TinyIntType;
import org.apache.paimon.types.VarCharType;

import java.math.BigDecimal;
import java.sql.Timestamp;
Expand All @@ -44,6 +60,7 @@
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import static org.apache.paimon.data.Timestamp.fromSQLTimestamp;
Expand All @@ -52,9 +69,11 @@ public class PaimonPredicateConverter extends ScalarOperatorVisitor<Predicate, V
private static final Logger LOG = LogManager.getLogger(PaimonPredicateConverter.class);
private final PredicateBuilder builder;
private final List<String> fieldNames;
private final List<DataType> fieldTypes;

public PaimonPredicateConverter(RowType rowType) {
this.builder = new PredicateBuilder(rowType);
this.fieldTypes = rowType.getFieldTypes();
this.fieldNames = rowType.getFields().stream().map(DataField::name).collect(Collectors.toList());
}

Expand Down Expand Up @@ -114,10 +133,13 @@ public Predicate visitBinaryPredicate(BinaryPredicateOperator operator, Void con
return null;
}
int idx = fieldNames.indexOf(columnName);
Object literal = getLiteral(operator.getChild(1));
Object literal = getLiteral(operator.getChild(1), fieldTypes.get(idx));
if (literal == null) {
return null;
}
if (fieldTypes.get(idx) instanceof BooleanType) {
literal = convertBoolLiteralValue(literal);
}
switch (operator.getBinaryType()) {
case LT:
return builder.lessThan(idx, literal);
Expand Down Expand Up @@ -146,10 +168,13 @@ public Predicate visitInPredicate(InPredicateOperator operator, Void context) {
List<ScalarOperator> valuesOperatorList = operator.getListChildren();
List<Object> literalValues = new ArrayList<>(valuesOperatorList.size());
for (ScalarOperator valueOperator : valuesOperatorList) {
Object value = getLiteral(valueOperator);
Object value = getLiteral(valueOperator, fieldTypes.get(idx));
if (value == null) {
return null;
}
if (fieldTypes.get(idx) instanceof BooleanType) {
value = convertBoolLiteralValue(value);
}
literalValues.add(value);
}

Expand All @@ -170,7 +195,11 @@ public Predicate visitLikePredicateOperator(LikePredicateOperator operator, Void
int idx = fieldNames.indexOf(columnName);
if (operator.getLikeType() == LikePredicateOperator.LikeType.LIKE) {
if (operator.getChild(1).getType().isStringType()) {
String literal = ((BinaryString) getLiteral(operator.getChild(1))).toString();
Object objectLiteral = getLiteral(operator.getChild(1), fieldTypes.get(idx));
if (objectLiteral == null) {
return null;
}
String literal = ((BinaryString) objectLiteral).toString();
if (literal.length() > 1 && literal.indexOf("%") == literal.length() - 1 && literal.charAt(0) != '%') {
return builder.startsWith(idx, BinaryString.fromString(literal.substring(0, literal.length() - 1)));
}
Expand All @@ -179,48 +208,128 @@ public Predicate visitLikePredicateOperator(LikePredicateOperator operator, Void
return null;
}

private Object getLiteral(ScalarOperator operator) {
if (!(operator instanceof ConstantOperator)) {
private Object getLiteral(ScalarOperator operator, DataType dataType) {
if (operator == null) {
return null;
}

ConstantOperator constValue = (ConstantOperator) operator;
switch (constValue.getType().getPrimitiveType()) {
case BOOLEAN:
return constValue.getBoolean();
case TINYINT:
return constValue.getTinyInt();
case SMALLINT:
return constValue.getSmallint();
case INT:
return constValue.getInt();
case BIGINT:
return constValue.getBigint();
case FLOAT:
return constValue.getFloat();
case DOUBLE:
return constValue.getDouble();
case DECIMALV2:
case DECIMAL32:
case DECIMAL64:
case DECIMAL128:
BigDecimal bigDecimal = constValue.getDecimal();
PrimitiveType type = constValue.getType().getPrimitiveType();
return Decimal.fromBigDecimal(bigDecimal, PrimitiveType.getMaxPrecisionOfDecimal(type),
PrimitiveType.getDefaultScaleOfDecimal(type));
case HLL:
case VARCHAR:
case CHAR:
return BinaryString.fromString(constValue.getVarchar());
case DATE:
LocalDate localDate = constValue.getDate().toLocalDate();
LocalDate epochDay = Instant.ofEpochSecond(0).atOffset(ZoneOffset.UTC).toLocalDate();
return (int) ChronoUnit.DAYS.between(epochDay, localDate);
case DATETIME:
LocalDateTime localDateTime = constValue.getDatetime();
return fromSQLTimestamp(Timestamp.valueOf((localDateTime)));
default:
return operator.accept(new PaimonPredicateConverter.ExtractLiteralValue(), dataType);
}

private static class ExtractLiteralValue extends ScalarOperatorVisitor<Object, DataType> {
@Override
public Object visit(ScalarOperator scalarOperator, DataType dataType) {
return null;
}

@Override
public Object visitConstant(ConstantOperator operator, DataType dataType) {
if (needCast(operator.getType().getPrimitiveType(), dataType)) {
operator = tryCastToResultType(operator, dataType);
}
if (operator == null) {
return null;
}
switch (operator.getType().getPrimitiveType()) {
case BOOLEAN:
return operator.getBoolean();
case TINYINT:
return operator.getTinyInt();
case SMALLINT:
return operator.getSmallint();
case INT:
return operator.getInt();
case BIGINT:
return operator.getBigint();
case FLOAT:
return operator.getFloat();
case DOUBLE:
return operator.getDouble();
case DECIMALV2:
case DECIMAL32:
case DECIMAL64:
case DECIMAL128:
BigDecimal bigDecimal = operator.getDecimal();
PrimitiveType type = operator.getType().getPrimitiveType();
return Decimal.fromBigDecimal(bigDecimal, PrimitiveType.getMaxPrecisionOfDecimal(type),
PrimitiveType.getDefaultScaleOfDecimal(type));
case HLL:
case VARCHAR:
case CHAR:
return BinaryString.fromString(operator.getVarchar());
case DATE:
LocalDate localDate = operator.getDate().toLocalDate();
LocalDate epochDay = Instant.ofEpochSecond(0).atOffset(ZoneOffset.UTC).toLocalDate();
return (int) ChronoUnit.DAYS.between(epochDay, localDate);
case DATETIME:
LocalDateTime localDateTime = operator.getDatetime();
return fromSQLTimestamp(Timestamp.valueOf((localDateTime)));
default:
return null;

}
}

@Override
public Object visitCastOperator(CastOperator operator, DataType dataType) {
return operator.getChild(0).accept(this, dataType);
}

private boolean needCast(PrimitiveType sourceType, DataType dataType) {
return switch (sourceType) {
case BOOLEAN -> !(dataType instanceof BooleanType);
case TINYINT -> !(dataType instanceof TinyIntType);
case SMALLINT -> !(dataType instanceof SmallIntType);
case INT -> !(dataType instanceof IntType);
case BIGINT -> !(dataType instanceof BigIntType);
case FLOAT -> !(dataType instanceof FloatType);
case DOUBLE -> !(dataType instanceof DoubleType);
case DECIMALV2, DECIMAL32, DECIMAL64, DECIMAL128 -> !(dataType instanceof DecimalType);
case HLL, VARCHAR -> !(dataType instanceof VarCharType);
case CHAR -> !(dataType instanceof CharType);
case DATE -> !(dataType instanceof DateType);
case DATETIME -> !(dataType instanceof TimestampType);
default -> true;
};
}

private ConstantOperator tryCastToResultType(ConstantOperator operator, DataType dataType) {
Optional<ConstantOperator> res = Optional.empty();
if (dataType instanceof BooleanType) {
res = operator.castTo(com.starrocks.catalog.Type.BOOLEAN);
} else if (dataType instanceof DateType) {
res = operator.castTo(com.starrocks.catalog.Type.DATE);
} else if (dataType instanceof TimestampType) {
res = operator.castTo(com.starrocks.catalog.Type.DATETIME);
} else if (dataType instanceof VarCharType) {
res = operator.castTo(com.starrocks.catalog.Type.STRING);
} else if (dataType instanceof CharType) {
res = operator.castTo(com.starrocks.catalog.Type.CHAR);
} else if (dataType instanceof BinaryType) {
res = operator.castTo(com.starrocks.catalog.Type.VARBINARY);
} else if (dataType instanceof IntType) {
res = operator.castTo(com.starrocks.catalog.Type.INT);
} else if (dataType instanceof BigIntType) {
res = operator.castTo(com.starrocks.catalog.Type.BIGINT);
} else if (dataType instanceof TinyIntType) {
res = operator.castTo(com.starrocks.catalog.Type.TINYINT);
} else if (dataType instanceof SmallIntType) {
res = operator.castTo(com.starrocks.catalog.Type.SMALLINT);
} else if (dataType instanceof FloatType) {
res = operator.castTo(com.starrocks.catalog.Type.FLOAT);
} else if (dataType instanceof DoubleType) {
res = operator.castTo(com.starrocks.catalog.Type.DOUBLE);
}
return res.orElse(operator);
}
}

// Support both 0/1 and true/false for boolean type
private static Object convertBoolLiteralValue(Object literalValue) {
try {
return new BoolLiteral(String.valueOf(literalValue)).getValue();
} catch (Exception e) {
throw new StarRocksConnectorException("Failed to convert %s to boolean type", literalValue);
}
}

Expand Down
Loading

0 comments on commit 8ec0ee3

Please sign in to comment.