Skip to content

Commit

Permalink
Use assertj for query assertions
Browse files Browse the repository at this point in the history
This makes it possible to write assertions as:

    assertThat(assertions.query("<query>"))
        .matches("<another query>")

and add any other assertions that might be relevant using the power of AssertJ
  • Loading branch information
martint committed Jun 5, 2020
1 parent e7899bf commit a381258
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 64 deletions.
154 changes: 137 additions & 17 deletions presto-main/src/test/java/io/prestosql/sql/query/QueryAssertions.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
*/
package io.prestosql.sql.query;

import com.google.common.base.Joiner;
import com.google.common.collect.Iterables;
import io.prestosql.Session;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.spi.PrestoException;
Expand All @@ -25,12 +23,21 @@
import io.prestosql.testing.MaterializedResult;
import io.prestosql.testing.MaterializedRow;
import io.prestosql.testing.QueryRunner;
import org.assertj.core.api.AbstractAssert;
import org.assertj.core.api.AssertProvider;
import org.assertj.core.api.ListAssert;
import org.assertj.core.api.ThrowableAssert;
import org.assertj.core.presentation.Representation;
import org.assertj.core.presentation.StandardRepresentation;
import org.intellij.lang.annotations.Language;

import java.io.Closeable;
import java.util.List;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

import static io.airlift.testing.Assertions.assertEqualsIgnoreOrder;
import static io.prestosql.sql.query.QueryAssertions.QueryAssert.newQueryAssert;
import static io.prestosql.testing.TestingSession.testSessionBuilder;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -61,6 +68,33 @@ public QueryAssertions(QueryRunner runner)
this.runner = requireNonNull(runner, "runner is null");
}

public Session.SessionBuilder sessionBuilder()
{
return Session.builder(runner.getDefaultSession());
}

public Session getDefaultSession()
{
return runner.getDefaultSession();
}

public AssertProvider<QueryAssert> query(@Language("SQL") String query)
{
return query(query, runner.getDefaultSession());
}

public AssertProvider<QueryAssert> query(@Language("SQL") String query, Session session)
{
return newQueryAssert(query, runner, session);
}

/**
* @deprecated use {@link org.assertj.core.api.Assertions#assertThatThrownBy(ThrowableAssert.ThrowingCallable)}:
* <pre>
* assertThatThrownBy(() -> assertions.execute(sql))<br>
* .hasMessage(...)
* </pre>
*/
public void assertFails(@Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp)
{
try {
Expand All @@ -86,22 +120,43 @@ public void assertQueryAndPlan(
PlanAssert.assertPlan(runner.getDefaultSession(), runner.getMetadata(), runner.getStatsCalculator(), plan, pattern);
}

/**
* @deprecated use {@link org.assertj.core.api.Assertions#assertThat} with {@link #query(String)}
*/
@Deprecated
public void assertQuery(@Language("SQL") String actual, @Language("SQL") String expected)
{
assertQuery(runner.getDefaultSession(), actual, expected, false);
}

/**
* @deprecated <p>use {@link org.assertj.core.api.Assertions#assertThat} with {@link #query(String, Session)}:
* <pre>
* assertThat(assertions.query(actual, session))<br>
* .matches(expected)
* </pre>
*/
@Deprecated
public void assertQuery(Session session, @Language("SQL") String actual, @Language("SQL") String expected)
{
assertQuery(session, actual, expected, false);
}

/**
* @deprecated use {@link org.assertj.core.api.Assertions#assertThat} with {@link #query(String)}:
* <pre>
* assertThat(assertions.query(actual))<br>
* .ordered()<br>
* .matches(expected)
* </pre>
*/
@Deprecated
public void assertQueryOrdered(@Language("SQL") String actual, @Language("SQL") String expected)
{
assertQuery(runner.getDefaultSession(), actual, expected, true);
}

public void assertQuery(Session session, @Language("SQL") String actual, @Language("SQL") String expected, boolean ensureOrdering)
private void assertQuery(Session session, @Language("SQL") String actual, @Language("SQL") String expected, boolean ensureOrdering)
{
MaterializedResult actualResults = null;
try {
Expand Down Expand Up @@ -147,20 +202,6 @@ public void assertQueryReturnsEmptyResult(@Language("SQL") String actual)
assertEquals(actualRows.size(), 0);
}

public static void assertContains(MaterializedResult all, MaterializedResult expectedSubset)
{
for (MaterializedRow row : expectedSubset.getMaterializedRows()) {
if (!all.getMaterializedRows().contains(row)) {
fail(format("expected row missing: %s\nAll %s rows:\n %s\nExpected subset %s rows:\n %s\n",
row,
all.getMaterializedRows().size(),
Joiner.on("\n ").join(Iterables.limit(all, 100)),
expectedSubset.getMaterializedRows().size(),
Joiner.on("\n ").join(Iterables.limit(expectedSubset, 100))));
}
}
}

public MaterializedResult execute(@Language("SQL") String query)
{
return execute(runner.getDefaultSession(), query);
Expand Down Expand Up @@ -194,4 +235,83 @@ protected void executeExclusively(Runnable executionBlock)
runner.getExclusiveLock().unlock();
}
}

public static class QueryAssert
extends AbstractAssert<QueryAssert, MaterializedResult>
{
private static final Representation ROWS_REPRESENTATION = new StandardRepresentation()
{
@Override
public String toStringOf(Object object)
{
if (object instanceof List) {
List<?> list = (List<?>) object;
return list.stream()
.map(this::toStringOf)
.collect(Collectors.joining(", "));
}
if (object instanceof MaterializedRow) {
MaterializedRow row = (MaterializedRow) object;

return row.getFields().stream()
.map(Object::toString)
.collect(Collectors.joining(", ", "(", ")"));
}
else {
return super.toStringOf(object);
}
}
};

private final QueryRunner runner;
private final Session session;
private boolean ordered;

static AssertProvider<QueryAssert> newQueryAssert(String query, QueryRunner runner, Session session)
{
MaterializedResult result = runner.execute(session, query);
return () -> new QueryAssert(runner, session, result);
}

public QueryAssert(QueryRunner runner, Session session, MaterializedResult actual)
{
super(actual, Object.class);
this.runner = runner;
this.session = session;
}

public QueryAssert matches(BiFunction<Session, QueryRunner, MaterializedResult> evaluator)
{
MaterializedResult expected = evaluator.apply(session, runner);
return isEqualTo(expected);
}

public QueryAssert ordered()
{
ordered = true;
return this;
}

public QueryAssert matches(@Language("SQL") String query)
{
MaterializedResult expected = runner.execute(session, query);

return satisfies(actual -> {
assertThat(actual.getTypes())
.as("Output types")
.isEqualTo(expected.getTypes());

ListAssert<MaterializedRow> assertion = assertThat(actual.getMaterializedRows())
.as("Rows")
.withRepresentation(ROWS_REPRESENTATION);

if (ordered) {
assertion.containsExactlyElementsOf(expected.getMaterializedRows());
}
else {
assertion.containsExactlyInAnyOrderElementsOf(expected.getMaterializedRows());
}
});
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,25 @@

import org.testng.annotations.Test;

import static org.assertj.core.api.Assertions.assertThat;

public class TestAggregationOverJoin
{
@Test
public void test()
{
// https://github.com/prestodb/presto/issues/10592
try (QueryAssertions queryAssertions = new QueryAssertions()) {
queryAssertions
.assertQuery(
try (QueryAssertions assertions = new QueryAssertions()) {
assertThat(assertions.query(
"WITH " +
" t (a, b) AS (VALUES (1, 'a'), (1, 'b')), " +
" u (a) AS (VALUES 1) " +
"SELECT DISTINCT v.a " +
"FROM ( " +
" SELECT DISTINCT a, b " +
" FROM t) v " +
"LEFT JOIN u on v.a = u.a",
"VALUES 1");
"LEFT JOIN u on v.a = u.a"))
.matches("VALUES 1");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import static org.assertj.core.api.Assertions.assertThat;

public class TestCorrelatedJoin
{
private QueryAssertions assertions;
Expand All @@ -37,8 +39,8 @@ public void teardown()
@Test
public void testJoinInCorrelatedJoinInput()
{
assertions.assertQuery(
"SELECT * FROM (VALUES 1) t1(a) JOIN (VALUES 2) t2(b) ON a < b, LATERAL (VALUES 3)",
"VALUES (1, 2, 3)");
assertThat(assertions.query(
"SELECT * FROM (VALUES 1) t1(a) JOIN (VALUES 2) t2(b) ON a < b, LATERAL (VALUES 3)"))
.matches("VALUES (1, 2, 3)");
}
}
29 changes: 12 additions & 17 deletions presto-main/src/test/java/io/prestosql/sql/query/TestFullJoin.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
*/
package io.prestosql.sql.query;

import io.prestosql.testing.MaterializedResult;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import static io.prestosql.sql.query.QueryAssertions.assertContains;
import static org.testng.Assert.assertEquals;
import static org.assertj.core.api.Assertions.assertThat;

public class TestFullJoin
{
Expand All @@ -41,25 +39,22 @@ public void teardown()
@Test
public void testFullJoinWithLimit()
{
MaterializedResult actual = assertions.execute(
"SELECT * FROM (VALUES 1, 2) AS l(v) FULL OUTER JOIN (VALUES 2, 1) AS r(v) ON l.v = r.v LIMIT 1");
assertThat(assertions.query(
"SELECT * FROM (VALUES 1, 2) AS l(v) FULL OUTER JOIN (VALUES 2, 1) AS r(v) ON l.v = r.v LIMIT 1"))
.satisfies(actual -> assertThat(actual.getMaterializedRows())
.hasSize(1)
.containsAnyElementsOf(assertions.execute("VALUES (1,1), (2,2)").getMaterializedRows()));

assertEquals(actual.getMaterializedRows().size(), 1);
assertContains(assertions.execute("VALUES (1,1), (2,2)"), actual);

assertions.assertQuery(
assertThat(assertions.query(
"SELECT * FROM (VALUES 1, 2) AS l(v) FULL OUTER JOIN (VALUES 2) AS r(v) ON l.v = r.v " +
"ORDER BY l.v NULLS FIRST " +
"LIMIT 1",
"VALUES (1, CAST(NULL AS INTEGER))");
"LIMIT 1"))
.matches("VALUES (1, CAST(NULL AS INTEGER))");

assertions.assertQuery(
assertThat(assertions.query(
"SELECT * FROM (VALUES 2) AS l(v) FULL OUTER JOIN (VALUES 1, 2) AS r(v) ON l.v = r.v " +
"ORDER BY r.v NULLS FIRST " +
"LIMIT 1",
"VALUES (CAST(NULL AS INTEGER), 1)");

assertEquals(actual.getMaterializedRows().size(), 1);
assertContains(assertions.execute("VALUES (1,1), (2,2)"), actual);
"LIMIT 1"))
.matches("VALUES (CAST(NULL AS INTEGER), 1)");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import static org.assertj.core.api.Assertions.assertThat;

public class TestJoin
{
private QueryAssertions assertions;
Expand All @@ -37,7 +39,7 @@ public void teardown()
@Test
public void testCrossJoinEliminationWithOuterJoin()
{
assertions.assertQuery(
assertThat(assertions.query(
"WITH " +
" a AS (SELECT id FROM (VALUES (1)) AS t(id))," +
" b AS (SELECT id FROM (VALUES (1)) AS t(id))," +
Expand All @@ -47,7 +49,7 @@ public void testCrossJoinEliminationWithOuterJoin()
"FROM a " +
"LEFT JOIN b ON a.id = b.id " +
"JOIN c ON a.id = CAST(c.id AS bigint) " +
"JOIN d ON d.id = a.id",
"VALUES 1");
"JOIN d ON d.id = a.id"))
.matches("VALUES 1");
}
}
Loading

0 comments on commit a381258

Please sign in to comment.