Skip to content

Commit

Permalink
[verifier] Add option to validate array(float) via error margin.
Browse files Browse the repository at this point in the history
Allow the verifier to validate array(float) and array(double) columns
the same way it validates float and double columns, i.e. within an error margin.
This can later be extended to maps as well.
  • Loading branch information
spershin committed Mar 25, 2024
1 parent 7d9c85d commit f62eb9a
Show file tree
Hide file tree
Showing 13 changed files with 367 additions and 55 deletions.
6 changes: 6 additions & 0 deletions presto-verifier/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,12 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.facebook.presto</groupId>
<artifactId>presto-testng-services</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import javax.annotation.Nullable;

import java.util.Objects;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static java.lang.String.format;

public class ArrayColumnChecksum
Expand All @@ -25,12 +27,19 @@ public class ArrayColumnChecksum
private final Object checksum;
private final Object cardinalityChecksum;
private final long cardinalitySum;
// For array(floating point) we have extra aggregations collected.
private final Optional<FloatingPointColumnChecksum> floatingPointChecksum;

/**
 * Creates the checksum holder for an array column.
 *
 * @param checksum checksum of the array values; may be null
 * @param cardinalityChecksum checksum of the array cardinalities; may be null
 * @param cardinalitySum sum of the array cardinalities
 * @param floatingPointChecksum extra aggregations collected for array(floating point) columns;
 *        callers on the non-floating-point path pass {@code Optional.empty()}
 */
public ArrayColumnChecksum(@Nullable Object checksum, @Nullable Object cardinalityChecksum, long cardinalitySum)
public ArrayColumnChecksum(
@Nullable Object checksum,
@Nullable Object cardinalityChecksum,
long cardinalitySum,
Optional<FloatingPointColumnChecksum> floatingPointChecksum)
{
this.checksum = checksum;
this.cardinalityChecksum = cardinalityChecksum;
this.cardinalitySum = cardinalitySum;
// Stored as-is; no null check is performed on the Optional here.
this.floatingPointChecksum = floatingPointChecksum;
}

@Nullable
Expand All @@ -52,6 +61,12 @@ public long getCardinalitySum()
return cardinalitySum;
}

/**
 * Returns the floating point aggregations attached to this array checksum.
 *
 * @return the floating point checksum held by this instance
 * @throws IllegalArgumentException if this instance holds no floating point checksum
 */
public FloatingPointColumnChecksum getFloatingPointChecksum()
{
    // Same exception type and message that the Preconditions-based check would produce.
    return floatingPointChecksum.orElseThrow(
            () -> new IllegalArgumentException("Expect Floating Point Checksum to be present, but it is not"));
}

@Override
public boolean equals(Object obj)
{
Expand All @@ -63,19 +78,29 @@ public boolean equals(Object obj)
}
ArrayColumnChecksum o = (ArrayColumnChecksum) obj;
return Objects.equals(checksum, o.checksum) &&
Objects.equals(floatingPointChecksum, o.floatingPointChecksum) &&
Objects.equals(cardinalityChecksum, o.cardinalityChecksum) &&
Objects.equals(cardinalitySum, o.cardinalitySum);
}

@Override
public int hashCode()
{
return Objects.hash(checksum, cardinalityChecksum, cardinalitySum);
// Hashes the same four fields that equals() compares, including the optional floating point checksum.
return Objects.hash(checksum, floatingPointChecksum, cardinalityChecksum, cardinalitySum);
}

/**
 * Renders the checksum for display. When a floating point checksum is present, its own
 * string form takes the place of the leading "checksum: ..." part.
 */
@Override
public String toString()
{
return format("checksum: %s, cardinality_checksum: %s, cardinality_sum: %s", checksum, cardinalityChecksum, cardinalitySum);
// Plain array: print the value checksum followed by the cardinality aggregates.
if (!floatingPointChecksum.isPresent()) {
return format("checksum: %s, cardinality_checksum: %s, cardinality_sum: %s", checksum, cardinalityChecksum, cardinalitySum);
}
else {
// Floating point array: the value checksum is not printed; the floating point checksum is printed instead.
return format(
"%s, cardinality_checksum: %s, cardinality_sum: %s",
floatingPointChecksum.get(),
cardinalityChecksum,
cardinalitySum);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,86 @@
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.LambdaArgumentDeclaration;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.SingleColumn;
import com.facebook.presto.sql.tree.TryExpression;
import com.facebook.presto.verifier.framework.Column;
import com.facebook.presto.verifier.framework.VerifierConfig;
import com.google.common.collect.ImmutableList;

import javax.inject.Inject;

import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.sql.QueryUtil.functionCall;
import static com.facebook.presto.sql.QueryUtil.identifier;
import static com.facebook.presto.verifier.framework.VerifierUtil.delimitedIdentifier;
import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

public class ArrayColumnValidator
implements ColumnValidator
{
private final FloatingPointColumnValidator floatingPointValidator;
private final boolean useErrorMarginForFloatingPointArrays;

/**
 * Creates the validator for array columns.
 *
 * @param config verifier configuration; read once for the
 *        "use error margin for floating point arrays" flag
 * @param floatingPointValidator validator that the floating point path delegates to
 * @throws NullPointerException if either argument is null
 */
@Inject
public ArrayColumnValidator(VerifierConfig config, FloatingPointColumnValidator floatingPointValidator)
{
    this.floatingPointValidator = requireNonNull(floatingPointValidator, "floatingPointValidator is null");
    // Guard config the same way as the validator so a missing binding fails with a descriptive message.
    this.useErrorMarginForFloatingPointArrays = requireNonNull(config, "config is null").isUseErrorMarginForFloatingPointArrays();
}

@Override
public List<SingleColumn> generateChecksumColumns(Column column)
{
Expression checksum = generateArrayChecksum(column.getExpression(), column.getType());
Type columnType = column.getType();
boolean useFloatingPointPath = useFloatingPointPath(column);

// For arrays of floating point numbers we have a different processing, akin to FloatingPointColumnValidator.
if (useFloatingPointPath) {
Type elementType = ((ArrayType) columnType).getElementType();
Expression expression = elementType.equals(DOUBLE) ? column.getExpression() : new Cast(column.getExpression(), new ArrayType(DOUBLE).getDisplayName());

Expression sum = functionCall(
"sum",
functionCall("array_sum", functionCall("filter", expression, generateLambdaExpression("is_finite"))));
Expression nanCount = functionCall(
"sum",
functionCall("cardinality", functionCall("filter", expression, generateLambdaExpression("is_nan"))));
Expression posInfCount = functionCall(
"sum",
functionCall("cardinality", functionCall("filter", expression, generateInfinityLambdaExpression(ArithmeticUnaryExpression.Sign.PLUS))));
Expression negInfCount = functionCall(
"sum",
functionCall("cardinality", functionCall("filter", expression, generateInfinityLambdaExpression(ArithmeticUnaryExpression.Sign.MINUS))));
Expression arrayCardinalityChecksum = functionCall("checksum", functionCall("cardinality", expression));
Expression arrayCardinalitySum = new CoalesceExpression(
functionCall("sum", functionCall("cardinality", expression)), new LongLiteral("0"));
return ImmutableList.of(
new SingleColumn(sum, Optional.of(delimitedIdentifier(FloatingPointColumnValidator.getSumColumnAlias(column)))),
new SingleColumn(nanCount, Optional.of(delimitedIdentifier(FloatingPointColumnValidator.getNanCountColumnAlias(column)))),
new SingleColumn(posInfCount, Optional.of(delimitedIdentifier(FloatingPointColumnValidator.getPositiveInfinityCountColumnAlias(column)))),
new SingleColumn(negInfCount, Optional.of(delimitedIdentifier(FloatingPointColumnValidator.getNegativeInfinityCountColumnAlias(column)))),
new SingleColumn(arrayCardinalityChecksum, Optional.of(delimitedIdentifier(getCardinalityChecksumColumnAlias(column)))),
new SingleColumn(arrayCardinalitySum, Optional.of(delimitedIdentifier(getCardinalitySumColumnAlias(column)))));
}

Expression checksum = generateArrayChecksum(column.getExpression(), columnType);
Expression arrayCardinalityChecksum = functionCall("checksum", functionCall("cardinality", column.getExpression()));
Expression arrayCardinalitySum = new CoalesceExpression(
functionCall("sum", functionCall("cardinality", column.getExpression())),
new LongLiteral("0"));

functionCall("sum", functionCall("cardinality", column.getExpression())), new LongLiteral("0"));
return ImmutableList.of(
new SingleColumn(checksum, Optional.of(delimitedIdentifier(getChecksumColumnAlias(column)))),
new SingleColumn(arrayCardinalityChecksum, Optional.of(delimitedIdentifier(getCardinalityChecksumColumnAlias(column)))),
Expand All @@ -55,10 +105,31 @@ public List<SingleColumn> generateChecksumColumns(Column column)
/**
 * Compares the control and test checksums of an array column.
 *
 * <p>Without the floating point path the result is a plain equality check of the two
 * checksums. With it, the cardinality checksum and cardinality sum must be equal first;
 * the floating point aggregates are then compared by {@code floatingPointValidator},
 * whose match flag and message are copied into the returned result.
 *
 * @param column the array column being validated
 * @param controlResult checksum query result of the control query
 * @param testResult checksum query result of the test query
 * @return a single-element list holding the match result
 * @throws IllegalArgumentException if the control and test row counts differ
 */
@Override
public List<ColumnMatchResult<ArrayColumnChecksum>> validate(Column column, ChecksumResult controlResult, ChecksumResult testResult)
{
ArrayColumnChecksum controlChecksum = toColumnChecksum(column, controlResult);
ArrayColumnChecksum testChecksum = toColumnChecksum(column, testResult);
// Row counts must agree before any checksum is compared.
checkArgument(
controlResult.getRowCount() == testResult.getRowCount(),
"Test row count (%s) does not match control row count (%s)",
testResult.getRowCount(),
controlResult.getRowCount());

boolean useFloatingPointPath = useFloatingPointPath(column);

ArrayColumnChecksum controlChecksum = toColumnChecksum(column, controlResult, useFloatingPointPath);
ArrayColumnChecksum testChecksum = toColumnChecksum(column, testResult, useFloatingPointPath);

return ImmutableList.of(new ColumnMatchResult<>(Objects.equals(controlChecksum, testChecksum), column, controlChecksum, testChecksum));
// Non-floating point case.
if (!useFloatingPointPath) {
return ImmutableList.of(new ColumnMatchResult<>(Objects.equals(controlChecksum, testChecksum), column, controlChecksum, testChecksum));
}

// Check the non-floating point members first.
if (!Objects.equals(controlChecksum.getCardinalityChecksum(), testChecksum.getCardinalityChecksum()) ||
!Objects.equals(controlChecksum.getCardinalitySum(), testChecksum.getCardinalitySum())) {
return ImmutableList.of(new ColumnMatchResult<>(false, column, Optional.of("cardinality mismatch"), controlChecksum, testChecksum));
}

// Delegate the floating point aggregates to the floating point validator and re-wrap its verdict
// around the array checksums.
ColumnMatchResult<FloatingPointColumnChecksum> result =
floatingPointValidator.validate(column, controlChecksum.getFloatingPointChecksum(), testChecksum.getFloatingPointChecksum());
return ImmutableList.of(new ColumnMatchResult<>(result.isMatched(), column, result.getMessage(), controlChecksum, testChecksum));
}

public static Expression generateArrayChecksum(Expression column, Type type)
Expand All @@ -67,24 +138,73 @@ public static Expression generateArrayChecksum(Expression column, Type type)
Type elementType = ((ArrayType) type).getElementType();

if (elementType.isOrderable()) {
FunctionCall arraySort = new FunctionCall(QualifiedName.of("array_sort"), ImmutableList.of(column));
Expression arraySort = functionCall("array_sort", column);

if (elementType instanceof ArrayType || elementType instanceof RowType) {
return new CoalesceExpression(
new FunctionCall(QualifiedName.of("checksum"), ImmutableList.of(new TryExpression(arraySort))),
new FunctionCall(QualifiedName.of("checksum"), ImmutableList.of(column)));
functionCall("checksum", new TryExpression(arraySort)),
functionCall("checksum", column));
}
return new FunctionCall(QualifiedName.of("checksum"), ImmutableList.of(arraySort));
return functionCall("checksum", arraySort);
}
return new FunctionCall(QualifiedName.of("checksum"), ImmutableList.of(column));
return functionCall("checksum", column);
}

/**
 * Builds a one-argument lambda whose body compares its argument {@code x} for equality
 * with {@code infinity()} carrying the given unary sign.
 *
 * @param sign sign applied to the {@code infinity()} call
 * @return the lambda expression
 */
private static Expression generateInfinityLambdaExpression(ArithmeticUnaryExpression.Sign sign)
{
    ArithmeticUnaryExpression signedInfinity = new ArithmeticUnaryExpression(sign, functionCall("infinity"));
    ComparisonExpression equalsSignedInfinity = new ComparisonExpression(
            ComparisonExpression.Operator.EQUAL,
            new Identifier("x"),
            signedInfinity);
    LambdaArgumentDeclaration argument = new LambdaArgumentDeclaration(identifier("x"));
    return new LambdaExpression(ImmutableList.of(argument), equalsSignedInfinity);
}

/**
 * Builds a one-argument lambda whose body calls {@code functionName} on its argument {@code x}.
 *
 * @param functionName name of the function to call on the lambda argument
 * @return the lambda expression
 */
private static Expression generateLambdaExpression(String functionName)
{
    LambdaArgumentDeclaration argument = new LambdaArgumentDeclaration(identifier("x"));
    Expression body = functionCall(functionName, new Identifier("x"));
    return new LambdaExpression(ImmutableList.of(argument), body);
}

/**
 * Extracts the array column checksum from a checksum query result.
 *
 * <p>On the regular path the value checksum, cardinality checksum and cardinality sum are
 * read and no floating point checksum is attached. On the floating point path the value
 * checksum is left null and a floating point checksum built from the sum, NaN count and
 * infinity counts is attached instead.
 */
private static ArrayColumnChecksum toColumnChecksum(Column column, ChecksumResult checksumResult)
private static ArrayColumnChecksum toColumnChecksum(Column column, ChecksumResult checksumResult, boolean useFloatingPointPath)
{
if (!useFloatingPointPath) {
return new ArrayColumnChecksum(
checksumResult.getChecksum(getChecksumColumnAlias(column)),
checksumResult.getChecksum(getCardinalityChecksumColumnAlias(column)),
(long) checksumResult.getChecksum(getCardinalitySumColumnAlias(column)),
Optional.empty());
}

// Case for an empty result table, when some aggregations return nulls.
// A null NaN count is used as the signal; all counters are then reported as zero.
Object nanCount = checksumResult.getChecksum(FloatingPointColumnValidator.getNanCountColumnAlias(column));
if (Objects.isNull(nanCount)) {
return new ArrayColumnChecksum(
null,
null,
0,
Optional.of(new FloatingPointColumnChecksum(null, 0, 0, 0, 0)));
}

// NOTE(review): cardinalitySum is also passed as the last FloatingPointColumnChecksum argument,
// presumably as its row count; confirm against the FloatingPointColumnChecksum constructor.
long cardinalitySum = (long) checksumResult.getChecksum(getCardinalitySumColumnAlias(column));
return new ArrayColumnChecksum(
checksumResult.getChecksum(getChecksumColumnAlias(column)),
null,
checksumResult.getChecksum(getCardinalityChecksumColumnAlias(column)),
(long) checksumResult.getChecksum(getCardinalitySumColumnAlias(column)));
cardinalitySum,
Optional.of(new FloatingPointColumnChecksum(
checksumResult.getChecksum(FloatingPointColumnValidator.getSumColumnAlias(column)),
(long) nanCount,
(long) checksumResult.getChecksum(FloatingPointColumnValidator.getPositiveInfinityCountColumnAlias(column)),
(long) checksumResult.getChecksum(FloatingPointColumnValidator.getNegativeInfinityCountColumnAlias(column)),
cardinalitySum)));
}

/**
 * Decides whether the column is validated through the floating point path.
 *
 * @param column the column to inspect; its type must be an {@code ArrayType}
 * @return true only when the error-margin flag is enabled and the array element type is
 *         contained in {@code Column.FLOATING_POINT_TYPES}
 * @throws IllegalArgumentException if the column type is not an {@code ArrayType}
 */
private boolean useFloatingPointPath(Column column)
{
    Type columnType = column.getType();
    checkArgument(columnType instanceof ArrayType, "Expect ArrayType, found %s", columnType.getDisplayName());
    if (!useErrorMarginForFloatingPointArrays) {
        return false;
    }
    ArrayType arrayType = (ArrayType) columnType;
    return Column.FLOATING_POINT_TYPES.contains(arrayType.getElementType());
}

private static String getChecksumColumnAlias(Column column)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ public long getNegativeInfinityCount()
return negativeInfinityCount;
}

/**
 * Returns the row count recorded in this checksum.
 */
public long getRowCount()
{
    return this.rowCount;
}

@Override
public boolean equals(Object obj)
{
Expand Down
Loading

0 comments on commit f62eb9a

Please sign in to comment.