From e4cca58f0e137901c1cc3f2f94211ccf60548b95 Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Fri, 27 Dec 2024 13:07:39 -0500 Subject: [PATCH] ESQL: Compute infrastruture for LEFT JOIN (#118889) This adds some infrastructure that we can use to run LOOKUP JOIN using real LEFT JOIN semantics. Right now if LOOKUP JOIN matches many rows in the `lookup` index we merge all of the values into a multivalued field. So the number of rows emitted from LOOKUP JOIN is the same as the number of rows that comes into LOOKUP JOIN. This change builds the infrastructure to emit one row per match, mostly reusing the infrastructure from ENRICH. --- .../compute/data/BooleanArrayBlock.java | 2 +- .../compute/data/BytesRefArrayBlock.java | 2 +- .../compute/data/DoubleArrayBlock.java | 2 +- .../compute/data/FloatArrayBlock.java | 2 +- .../compute/data/IntArrayBlock.java | 2 +- .../compute/data/LongArrayBlock.java | 2 +- .../org/elasticsearch/compute/data/Block.java | 38 +- .../compute/data/X-ArrayBlock.java.st | 2 +- .../lookup/MergePositionsOperator.java | 24 +- .../operator/lookup/RightChunkedLeftJoin.java | 253 ++++++++++ .../compute/data/BasicBlockTests.java | 64 +++ .../compute/data/BlockMultiValuedTests.java | 12 + .../compute/operator/ComputeTestCase.java | 11 + .../compute/operator/LimitOperatorTests.java | 16 +- .../lookup/RightChunkedLeftJoinTests.java | 434 ++++++++++++++++++ .../xpack/esql/action/LookupFromIndexIT.java | 63 ++- 16 files changed, 883 insertions(+), 46 deletions(-) create mode 100644 x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/RightChunkedLeftJoin.java create mode 100644 x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/lookup/RightChunkedLeftJoinTests.java diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanArrayBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanArrayBlock.java index 3d600bec1bd65..896662dddf1eb 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanArrayBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BooleanArrayBlock.java @@ -122,7 +122,7 @@ public BooleanBlock filter(int... positions) { int valueCount = getValueCount(pos); int first = getFirstValueIndex(pos); if (valueCount == 1) { - builder.appendBoolean(getBoolean(getFirstValueIndex(pos))); + builder.appendBoolean(getBoolean(first)); } else { builder.beginPositionEntry(); for (int c = 0; c < valueCount; c++) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefArrayBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefArrayBlock.java index f50135aa51dd4..5bcb1b0ec5095 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefArrayBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/BytesRefArrayBlock.java @@ -110,7 +110,7 @@ public BytesRefBlock filter(int... 
positions) { int valueCount = getValueCount(pos); int first = getFirstValueIndex(pos); if (valueCount == 1) { - builder.appendBytesRef(getBytesRef(getFirstValueIndex(pos), scratch)); + builder.appendBytesRef(getBytesRef(first, scratch)); } else { builder.beginPositionEntry(); for (int c = 0; c < valueCount; c++) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleArrayBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleArrayBlock.java index eceec30348749..20bd42da98c71 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleArrayBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/DoubleArrayBlock.java @@ -101,7 +101,7 @@ public DoubleBlock filter(int... positions) { int valueCount = getValueCount(pos); int first = getFirstValueIndex(pos); if (valueCount == 1) { - builder.appendDouble(getDouble(getFirstValueIndex(pos))); + builder.appendDouble(getDouble(first)); } else { builder.beginPositionEntry(); for (int c = 0; c < valueCount; c++) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatArrayBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatArrayBlock.java index 56f0cedb5f15e..c0941557dc4fe 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatArrayBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/FloatArrayBlock.java @@ -101,7 +101,7 @@ public FloatBlock filter(int... positions) { int valueCount = getValueCount(pos); int first = getFirstValueIndex(pos); if (valueCount == 1) { - builder.appendFloat(getFloat(getFirstValueIndex(pos))); + builder.appendFloat(getFloat(first)); } else { builder.beginPositionEntry(); for (int c = 0; c < valueCount; c++) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntArrayBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntArrayBlock.java index 2e10d09e1a410..8ced678bc90b0 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntArrayBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/IntArrayBlock.java @@ -101,7 +101,7 @@ public IntBlock filter(int... positions) { int valueCount = getValueCount(pos); int first = getFirstValueIndex(pos); if (valueCount == 1) { - builder.appendInt(getInt(getFirstValueIndex(pos))); + builder.appendInt(getInt(first)); } else { builder.beginPositionEntry(); for (int c = 0; c < valueCount; c++) { diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongArrayBlock.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongArrayBlock.java index 776fa363f6080..fb631ab326ce7 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongArrayBlock.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/data/LongArrayBlock.java @@ -101,7 +101,7 @@ public LongBlock filter(int... 
positions) { int valueCount = getValueCount(pos); int first = getFirstValueIndex(pos); if (valueCount == 1) { - builder.appendLong(getLong(getFirstValueIndex(pos))); + builder.appendLong(getLong(first)); } else { builder.beginPositionEntry(); for (int c = 0; c < valueCount; c++) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/Block.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/Block.java index 1e06cf1ea4450..edf54a829deba 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/Block.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/Block.java @@ -212,10 +212,46 @@ default boolean mvSortedAscending() { /** * Expand multivalued fields into one row per value. Returns the same block if there aren't any multivalued * fields to expand. The returned block needs to be closed by the caller to release the block's resources. - * TODO: pass BlockFactory */ Block expand(); + /** + * Build a {@link Block} with a {@code null} inserted {@code before} each + * listed position. + *

+ * Note: {@code before} must be non-decreasing. + *
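+ * For example, given a three-position block {@code [a, b, c]} and
+ * {@code before = [1, 1, 3]}, the result is {@code [a, null, null, b, c, null]}:
+ * two nulls are inserted ahead of position 1, and a {@code before} value equal
+ * to the position count appends a trailing null.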

+ */ + default Block insertNulls(IntVector before) { + // TODO remove default and scatter to implementation where it can be a lot more efficient + int myCount = getPositionCount(); + int beforeCount = before.getPositionCount(); + try (Builder builder = elementType().newBlockBuilder(myCount + beforeCount, blockFactory())) { + int beforeP = 0; + int nextNull = before.getInt(beforeP); + for (int mainP = 0; mainP < myCount; mainP++) { + while (mainP == nextNull) { + builder.appendNull(); + beforeP++; + if (beforeP >= beforeCount) { + builder.copyFrom(this, mainP, myCount); + return builder.build(); + } + nextNull = before.getInt(beforeP); + } + // This line right below this is the super inefficient one. + builder.copyFrom(this, mainP, mainP + 1); + } + assert nextNull == myCount; + while (beforeP < beforeCount) { + nextNull = before.getInt(beforeP++); + assert nextNull == myCount; + builder.appendNull(); + } + return builder.build(); + } + } + /** * Builds {@link Block}s. Typically, you use one of it's direct supinterfaces like {@link IntBlock.Builder}. * This is {@link Releasable} and should be released after building the block or if building the block fails. diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ArrayBlock.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ArrayBlock.java.st index e855e6d6296d8..16e2a62b9d030 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ArrayBlock.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/X-ArrayBlock.java.st @@ -149,7 +149,7 @@ $endif$ int valueCount = getValueCount(pos); int first = getFirstValueIndex(pos); if (valueCount == 1) { - builder.append$Type$(get$Type$(getFirstValueIndex(pos)$if(BytesRef)$, scratch$endif$)); + builder.append$Type$(get$Type$(first$if(BytesRef)$, scratch$endif$)); } else { builder.beginPositionEntry(); for (int c = 0; c < valueCount; c++) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/MergePositionsOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/MergePositionsOperator.java index d42655446ca10..578e1c046954b 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/MergePositionsOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/MergePositionsOperator.java @@ -20,21 +20,24 @@ import java.util.Objects; /** - * Combines values at the given blocks with the same positions into a single position for the blocks at the given channels + * Combines values at the given blocks with the same positions into a single position + * for the blocks at the given channels. + *

* Example, input pages consisting of three blocks: - * positions | field-1 | field-2 | - * ----------------------------------- + *

+ *
{@code
+ * | positions    | field-1 | field-2 |
+ * ------------------------------------
  * Page 1:
- * 1           |  a,b    |   2020  |
- * 1           |  c      |   2021  |
- * ---------------------------------
+ * | 1            |  a,b    |   2020  |
+ * | 1            |  c      |   2021  |
  * Page 2:
- * 2           |  a,e    |   2021  |
- * ---------------------------------
+ * | 2            |  a,e    |   2021  |
  * Page 3:
- * 4           |  d      |   null  |
- * ---------------------------------
+ * | 4            |  d      |   null  |
+ * }
* Output: + *
{@code
  * |  field-1   | field-2    |
  * ---------------------------
  * |  null      | null       |
@@ -42,6 +45,7 @@
  * |  a,e       | 2021       |
  * |  null      | null       |
  * |  d         | 2023       |
+ * }
*/ public final class MergePositionsOperator implements Operator { private boolean finished = false; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/RightChunkedLeftJoin.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/RightChunkedLeftJoin.java new file mode 100644 index 0000000000000..f9895ff346b5c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/lookup/RightChunkedLeftJoin.java @@ -0,0 +1,253 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator.lookup; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.Releasables; + +import java.util.Optional; +import java.util.stream.IntStream; + +/** + * Performs a {@code LEFT JOIN} where many "right hand" pages are joined + * against a "left hand" {@link Page}. Each row on the "left hand" page + * is output at least once whether it appears in the "right hand" or not. + * And more than once if it appears in the "right hand" pages more than once. + *

+ * The "right hand" page contains a non-decreasing {@code positions}
+ * column that controls which position in the "left hand" page each row
+ * in the "right hand" page is joined against. This'll make more sense
+ * with a picture:

+ *
{@code
+ * "left hand"                 "right hand"
+ * | lhdata |             | positions | r1 | r2 |
+ * ----------             -----------------------
+ * |    l00 |             |         0 |  1 |  2 |
+ * |    l01 |             |         1 |  2 |  3 |
+ * |    l02 |             |         1 |  3 |  3 |
+ * |    ... |             |         3 |  9 |  9 |
+ * |    l99 |
+ * }
+ *

+ * Joins to: + *

+ *
{@code
+ * | lhdata  |  r1  |  r2  |
+ * -------------------------
+ * |     l00 |    1 |    2 |
+ * |     l01 |    2 |    3 |
+ * |     l01 |    3 |    3 |   <1>
+ * |     l02 | null | null |   <2>
+ * |     l03 |    9 |    9 |
+ * }
+ *
+ * <1> {@code l01} is duplicated because its position appears twice in
+ *     the right hand page.
+ * <2> {@code l02}'s row is filled with {@code null}s because its position
+ *     does not appear in the right hand page.
+ *

+ * This supports joining many "right hand" pages against the same
+ * "left hand" so long as the first value of the next {@code positions}
+ * column is the same as or greater than the last value of the previous
+ * {@code positions} column. Large gaps are fine. Starting with the
+ * same number as you ended on is fine. This looks like:

+ *
{@code
+ * "left hand"                 "right hand"
+ * | lhdata |             | positions | r1 | r2 |
+ * ----------             -----------------------
+ * |    l00 |                    page 1
+ * |    l01 |             |         0 |  1 |  2 |
+ * |    l02 |             |         1 |  3 |  3 |
+ * |    l03 |                    page 2
+ * |    l04 |             |         1 |  9 |  9 |
+ * |    l05 |             |         2 |  9 |  9 |
+ * |    l06 |                    page 3
+ * |    ... |             |         5 | 10 | 10 |
+ * |    l99 |             |         7 | 11 | 11 |
+ * }
+ *

+ * Which makes: + *

+ *
{@code
+ * | lhdata  |  r1  |  r2  |
+ * -------------------------
+ *         page 1
+ * |     l00 |    1 |    2 |
+ * |     l01 |    3 |    3 |
+ *         page 2
+ * |     l01 |    9 |    9 |
+ * |     l02 |    9 |    9 |
+ *         page 3
+ * |     l03 | null | null |
+ * |     l04 | null | null |
+ * |     l05 |   10 |   10 |
+ * |     l06 | null | null |
+ * |     l07 |   11 |   11 |
+ * }
+ *

+ * Note that the output pages are sized by the "right hand" pages with + * {@code null}s inserted. + *

+ *

+ * Finally, after all "right hand" pages have been joined, this will produce
+ * all remaining "left hand" rows joined against {@code null}.
+ * Another picture:

+ *
{@code
+ * "left hand"                 "right hand"
+ * | lhdata |             | positions | r1 | r2 |
+ * ----------             -----------------------
+ * |    l00 |                    last page
+ * |    l01 |             |        96 |  1 |  2 |
+ * |    ... |             |        97 |  2 |  3 |
+ * |    l99 |
+ * }
+ *

+ * Which makes: + *

+ *
{@code
+ * | lhdata  |  r1  |  r2  |
+ * -------------------------
+ *     last matching page
+ * |     l96 |    1 |    2 |
+ * |     l97 |    2 |    3 |
+ *    trailing nulls page
+ * |     l98 | null | null |
+ * |     l99 | null | null |
+ * }
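+ *
+ * A rough sketch of the intended call pattern - illustrative only, where
+ * {@code rightHandPages} and {@code emit} stand in for however the calling
+ * operator supplies right hand pages and consumes joined output:
+ * {@code
+ *     try (RightChunkedLeftJoin join = new RightChunkedLeftJoin(leftHandPage, mergedElementCount)) {
+ *         for (Page rightHand : rightHandPages) {
+ *             emit(join.join(rightHand));
+ *         }
+ *         join.noMoreRightHandPages().ifPresent(remaining -> emit(remaining));
+ *     }
+ * }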
+ */ +class RightChunkedLeftJoin implements Releasable { + private final Page leftHand; + private final int mergedElementCount; + /** + * The next position that we'll emit or one more than the + * next position we'll emit. This is used to cover gaps between "right hand" + * pages and to detect if "right hand" pages "go backwards". + */ + private int next = 0; + + RightChunkedLeftJoin(Page leftHand, int mergedElementCounts) { + this.leftHand = leftHand; + this.mergedElementCount = mergedElementCounts; + } + + Page join(Page rightHand) { + IntVector positions = rightHand.getBlock(0).asVector(); + if (positions.getInt(0) < next - 1) { + throw new IllegalArgumentException("maximum overlap is one position"); + } + Block[] blocks = new Block[leftHand.getBlockCount() + mergedElementCount]; + if (rightHand.getBlockCount() != mergedElementCount + 1) { + throw new IllegalArgumentException( + "expected right hand side with [" + (mergedElementCount + 1) + "] but got [" + rightHand.getBlockCount() + "]" + ); + } + IntVector.Builder leftFilterBuilder = null; + IntVector leftFilter = null; + IntVector.Builder insertNullsBuilder = null; + IntVector insertNulls = null; + try { + leftFilterBuilder = positions.blockFactory().newIntVectorBuilder(positions.getPositionCount()); + for (int p = 0; p < positions.getPositionCount(); p++) { + int pos = positions.getInt(p); + if (pos > next) { + if (insertNullsBuilder == null) { + insertNullsBuilder = positions.blockFactory().newIntVectorBuilder(pos - next); + } + for (int missing = next; missing < pos; missing++) { + leftFilterBuilder.appendInt(missing); + insertNullsBuilder.appendInt(p); + } + } + leftFilterBuilder.appendInt(pos); + next = pos + 1; + } + leftFilter = leftFilterBuilder.build(); + int[] leftFilterArray = toArray(leftFilter); + insertNulls = insertNullsBuilder == null ? null : insertNullsBuilder.build(); + + int b = 0; + while (b < leftHand.getBlockCount()) { + blocks[b] = leftHand.getBlock(b).filter(leftFilterArray); + b++; + } + int rb = 1; // Skip the positions column + while (b < blocks.length) { + Block block = rightHand.getBlock(rb); + if (insertNulls == null) { + block.mustIncRef(); + } else { + block = block.insertNulls(insertNulls); + } + blocks[b] = block; + b++; + rb++; + } + Page result = new Page(blocks); + blocks = null; + return result; + } finally { + Releasables.close( + blocks == null ? null : Releasables.wrap(blocks), + leftFilter, + leftFilterBuilder, + insertNullsBuilder, + insertNulls + ); + } + } + + Optional noMoreRightHandPages() { + if (next == leftHand.getPositionCount()) { + return Optional.empty(); + } + BlockFactory factory = leftHand.getBlock(0).blockFactory(); + Block[] blocks = new Block[leftHand.getBlockCount() + mergedElementCount]; + // TODO make a filter that takes a min and max? 
+ int[] filter = IntStream.range(next, leftHand.getPositionCount()).toArray(); + try { + int b = 0; + while (b < leftHand.getBlockCount()) { + blocks[b] = leftHand.getBlock(b).filter(filter); + b++; + } + while (b < blocks.length) { + blocks[b] = factory.newConstantNullBlock(leftHand.getPositionCount() - next); + b++; + } + Page result = new Page(blocks); + blocks = null; + return Optional.of(result); + } finally { + if (blocks != null) { + Releasables.close(blocks); + } + } + } + + @Override + public void close() { + Releasables.close(leftHand::releaseBlocks); + } + + private int[] toArray(IntVector vector) { + // TODO replace parameter to filter with vector and remove this + int[] array = new int[vector.getPositionCount()]; + for (int p = 0; p < vector.getPositionCount(); p++) { + array[p] = vector.getInt(p); + } + return array; + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BasicBlockTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BasicBlockTests.java index 439ebe34c7d4a..33a294131c996 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BasicBlockTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BasicBlockTests.java @@ -224,6 +224,7 @@ public void testIntBlock() { try (IntBlock.Builder blockBuilder = blockFactory.newIntBlockBuilder(1)) { IntBlock copy = blockBuilder.copyFrom(block, 0, block.getPositionCount()).build(); assertThat(copy, equalTo(block)); + assertInsertNulls(block); releaseAndAssertBreaker(block, copy); } @@ -250,6 +251,7 @@ public void testIntBlock() { assertSingleValueDenseBlock(vector.asBlock()); assertThat(vector.min(), equalTo(0)); assertThat(vector.max(), equalTo(positionCount - 1)); + assertInsertNulls(vector.asBlock()); releaseAndAssertBreaker(vector.asBlock()); } } @@ -272,12 +274,14 @@ public void testIntBlockEmpty() { assertEmptyLookup(blockFactory, block); assertThat(block.asVector().min(), equalTo(Integer.MAX_VALUE)); assertThat(block.asVector().max(), equalTo(Integer.MIN_VALUE)); + assertInsertNulls(block); releaseAndAssertBreaker(block); try (IntVector.Builder vectorBuilder = blockFactory.newIntVectorBuilder(0)) { IntVector vector = vectorBuilder.build(); assertThat(vector.min(), equalTo(Integer.MAX_VALUE)); assertThat(vector.max(), equalTo(Integer.MIN_VALUE)); + assertInsertNulls(vector.asBlock()); releaseAndAssertBreaker(vector.asBlock()); } } @@ -317,6 +321,7 @@ public void testConstantIntBlock() { assertEmptyLookup(blockFactory, block); assertThat(block.asVector().min(), equalTo(value)); assertThat(block.asVector().max(), equalTo(value)); + assertInsertNulls(block); releaseAndAssertBreaker(block); } } @@ -350,6 +355,7 @@ public void testLongBlock() { try (LongBlock.Builder blockBuilder = blockFactory.newLongBlockBuilder(1)) { LongBlock copy = blockBuilder.copyFrom(block, 0, block.getPositionCount()).build(); assertThat(copy, equalTo(block)); + assertInsertNulls(block); releaseAndAssertBreaker(block, copy); } @@ -372,6 +378,7 @@ public void testLongBlock() { LongStream.range(0, positionCount).forEach(vectorBuilder::appendLong); LongVector vector = vectorBuilder.build(); assertSingleValueDenseBlock(vector.asBlock()); + assertInsertNulls(vector.asBlock()); releaseAndAssertBreaker(vector.asBlock()); } } @@ -408,6 +415,7 @@ public void testConstantLongBlock() { b -> assertThat(b, instanceOf(ConstantNullBlock.class)) ); assertEmptyLookup(blockFactory, block); + assertInsertNulls(block); 
releaseAndAssertBreaker(block); } } @@ -442,6 +450,7 @@ public void testDoubleBlock() { try (DoubleBlock.Builder blockBuilder = blockFactory.newDoubleBlockBuilder(1)) { DoubleBlock copy = blockBuilder.copyFrom(block, 0, block.getPositionCount()).build(); assertThat(copy, equalTo(block)); + assertInsertNulls(block); releaseAndAssertBreaker(block, copy); } @@ -466,6 +475,7 @@ public void testDoubleBlock() { IntStream.range(0, positionCount).mapToDouble(ii -> 1.0 / ii).forEach(vectorBuilder::appendDouble); DoubleVector vector = vectorBuilder.build(); assertSingleValueDenseBlock(vector.asBlock()); + assertInsertNulls(vector.asBlock()); releaseAndAssertBreaker(vector.asBlock()); } } @@ -501,6 +511,7 @@ public void testConstantDoubleBlock() { b -> assertThat(b, instanceOf(ConstantNullBlock.class)) ); assertEmptyLookup(blockFactory, block); + assertInsertNulls(block); releaseAndAssertBreaker(block); } } @@ -536,6 +547,7 @@ public void testFloatBlock() { try (FloatBlock.Builder blockBuilder = blockFactory.newFloatBlockBuilder(1)) { FloatBlock copy = blockBuilder.copyFrom(block, 0, block.getPositionCount()).build(); assertThat(copy, equalTo(block)); + assertInsertNulls(block); releaseAndAssertBreaker(block, copy); } @@ -560,6 +572,7 @@ public void testFloatBlock() { IntStream.range(0, positionCount).mapToDouble(ii -> 1.0 / ii).forEach(vectorBuilder::appendDouble); DoubleVector vector = vectorBuilder.build(); assertSingleValueDenseBlock(vector.asBlock()); + assertInsertNulls(vector.asBlock()); releaseAndAssertBreaker(vector.asBlock()); } } @@ -595,6 +608,7 @@ public void testConstantFloatBlock() { b -> assertThat(b, instanceOf(ConstantNullBlock.class)) ); assertEmptyLookup(blockFactory, block); + assertInsertNulls(block); releaseAndAssertBreaker(block); } } @@ -646,6 +660,7 @@ private void testBytesRefBlock(Supplier byteArraySupplier, boolean cho try (BytesRefBlock.Builder blockBuilder = blockFactory.newBytesRefBlockBuilder(1)) { BytesRefBlock copy = blockBuilder.copyFrom(block, 0, block.getPositionCount()).build(); assertThat(copy, equalTo(block)); + assertInsertNulls(block); releaseAndAssertBreaker(block, copy); } @@ -671,6 +686,7 @@ private void testBytesRefBlock(Supplier byteArraySupplier, boolean cho IntStream.range(0, positionCount).mapToObj(ii -> new BytesRef(randomAlphaOfLength(5))).forEach(vectorBuilder::appendBytesRef); BytesRefVector vector = vectorBuilder.build(); assertSingleValueDenseBlock(vector.asBlock()); + assertInsertNulls(vector.asBlock()); releaseAndAssertBreaker(vector.asBlock()); } } @@ -726,6 +742,7 @@ public void testBytesRefBlockBuilderWithNulls() { } } assertKeepMask(block); + assertInsertNulls(block); releaseAndAssertBreaker(block); } } @@ -765,6 +782,7 @@ public void testConstantBytesRefBlock() { b -> assertThat(b, instanceOf(ConstantNullBlock.class)) ); assertEmptyLookup(blockFactory, block); + assertInsertNulls(block); releaseAndAssertBreaker(block); } } @@ -810,6 +828,7 @@ public void testBooleanBlock() { try (BooleanBlock.Builder blockBuilder = blockFactory.newBooleanBlockBuilder(1)) { BooleanBlock copy = blockBuilder.copyFrom(block, 0, block.getPositionCount()).build(); assertThat(copy, equalTo(block)); + assertInsertNulls(block); releaseAndAssertBreaker(block, copy); } @@ -852,6 +871,7 @@ public void testBooleanBlock() { assertTrue(vector.allFalse()); } } + assertInsertNulls(vector.asBlock()); releaseAndAssertBreaker(vector.asBlock()); } } @@ -893,6 +913,7 @@ public void testConstantBooleanBlock() { assertFalse(block.asVector().allTrue()); 
assertTrue(block.asVector().allFalse()); } + assertInsertNulls(block); releaseAndAssertBreaker(block); } } @@ -935,6 +956,7 @@ public void testConstantNullBlock() { singletonList(null), b -> assertThat(b, instanceOf(ConstantNullBlock.class)) ); + assertInsertNulls(block); releaseAndAssertBreaker(block); } } @@ -1390,6 +1412,7 @@ void assertNullValues( asserter.accept(randomNonNullPosition, block); assertTrue(block.isNull(randomNullPosition)); assertFalse(block.isNull(randomNonNullPosition)); + assertInsertNulls(block); releaseAndAssertBreaker(block); if (block instanceof BooleanBlock bb) { try (ToMask mask = bb.toMask()) { @@ -1409,6 +1432,7 @@ void assertZeroPositionsAndRelease(BooleanBlock block) { void assertZeroPositionsAndRelease(Block block) { assertThat(block.getPositionCount(), is(0)); assertKeepMaskEmpty(block); + assertInsertNulls(block); releaseAndAssertBreaker(block); } @@ -1451,6 +1475,36 @@ static void assertToMask(BooleanVector vector) { } } + static void assertInsertNulls(Block block) { + int maxNulls = Math.min(1000, block.getPositionCount() * 5); + List orig = new ArrayList<>(block.getPositionCount()); + BlockTestUtils.readInto(orig, block); + + int nullCount = 0; + try (IntVector.Builder beforeBuilder = block.blockFactory().newIntVectorBuilder(block.getPositionCount())) { + List expected = new ArrayList<>(block.getPositionCount()); + for (int p = 0; p < block.getPositionCount(); p++) { + while (nullCount < maxNulls && randomBoolean()) { + expected.add(null); + beforeBuilder.appendInt(p); + nullCount++; + } + expected.add(orig.get(p)); + } + while (nullCount == 0 || (nullCount < maxNulls && randomBoolean())) { + expected.add(null); + beforeBuilder.appendInt(block.getPositionCount()); + nullCount++; + } + + try (IntVector before = beforeBuilder.build(); Block withNulls = block.insertNulls(before)) { + List actual = new ArrayList<>(block.getPositionCount()); + BlockTestUtils.readInto(actual, withNulls); + assertThat(actual, equalTo(expected)); + } + } + } + void releaseAndAssertBreaker(Block... blocks) { assertThat(breaker.getUsed(), greaterThan(0L)); Page[] pages = Arrays.stream(blocks).map(Page::new).toArray(Page[]::new); @@ -1909,4 +1963,14 @@ static BooleanVector randomMask(int positions) { return builder.build(); } } + + /** + * A random {@link ElementType} for which we can build a {@link RandomBlock}. 
+ */ + public static ElementType randomElementType() { + return randomValueOtherThanMany( + e -> e == ElementType.UNKNOWN || e == ElementType.NULL || e == ElementType.DOC || e == ElementType.COMPOSITE, + () -> randomFrom(ElementType.values()) + ); + } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockMultiValuedTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockMultiValuedTests.java index e37b2638b56f7..da7b8cd87db7d 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockMultiValuedTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockMultiValuedTests.java @@ -29,6 +29,7 @@ import java.util.function.IntUnaryOperator; import java.util.stream.IntStream; +import static org.elasticsearch.compute.data.BasicBlockTests.assertInsertNulls; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.nullValue; @@ -68,6 +69,7 @@ public void testMultiValued() { assertThat(b.block().mayHaveMultivaluedFields(), equalTo(b.values().stream().anyMatch(l -> l != null && l.size() > 1))); assertThat(b.block().doesHaveMultivaluedFields(), equalTo(b.values().stream().anyMatch(l -> l != null && l.size() > 1))); + assertInsertNulls(b.block()); } finally { b.block().close(); } @@ -171,6 +173,16 @@ public void testMask() { } } + public void testInsertNull() { + int positionCount = randomIntBetween(1, 16 * 1024); + var b = BasicBlockTests.randomBlock(blockFactory(), elementType, positionCount, nullAllowed, 2, 10, 0, 0); + try { + assertInsertNulls(b.block()); + } finally { + b.block().close(); + } + } + private void assertFiltered(boolean all, boolean shuffled) { int positionCount = randomIntBetween(1, 16 * 1024); var b = BasicBlockTests.randomBlock(blockFactory(), elementType, positionCount, nullAllowed, 0, 10, 0, 0); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ComputeTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ComputeTestCase.java index ce62fb9896eba..cf99c59bb4c71 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ComputeTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ComputeTestCase.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.function.Consumer; import static org.hamcrest.Matchers.equalTo; @@ -76,6 +77,16 @@ protected final BlockFactory crankyBlockFactory() { return blockFactory; } + protected final void testWithCrankyBlockFactory(Consumer run) { + try { + run.accept(crankyBlockFactory()); + logger.info("cranky let us finish!"); + } catch (CircuitBreakingException e) { + logger.info("cranky", e); + assertThat(e.getMessage(), equalTo(CrankyCircuitBreakerService.ERROR_MESSAGE)); + } + } + @After public final void allBreakersEmpty() throws Exception { // first check that all big arrays are released, which can affect breakers diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/LimitOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/LimitOperatorTests.java index 8200529e18290..bbe4a07cc44bd 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/LimitOperatorTests.java +++ 
b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/LimitOperatorTests.java @@ -10,14 +10,13 @@ import org.elasticsearch.compute.data.BasicBlockTests; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; -import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.data.Page; import org.hamcrest.Matcher; -import java.util.ArrayList; import java.util.List; import java.util.stream.LongStream; +import static org.elasticsearch.compute.data.BasicBlockTests.randomElementType; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.sameInstance; @@ -129,17 +128,6 @@ Block randomBlock(BlockFactory blockFactory, int size) { if (randomBoolean()) { return blockFactory.newConstantNullBlock(size); } - return BasicBlockTests.randomBlock(blockFactory, randomElement(), size, false, 1, 1, 0, 0).block(); - } - - static ElementType randomElement() { - List l = new ArrayList<>(); - for (ElementType e : ElementType.values()) { - if (e == ElementType.UNKNOWN || e == ElementType.NULL || e == ElementType.DOC || e == ElementType.COMPOSITE) { - continue; - } - l.add(e); - } - return randomFrom(l); + return BasicBlockTests.randomBlock(blockFactory, randomElementType(), size, false, 1, 1, 0, 0).block(); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/lookup/RightChunkedLeftJoinTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/lookup/RightChunkedLeftJoinTests.java new file mode 100644 index 0000000000000..1312b772dbfa1 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/lookup/RightChunkedLeftJoinTests.java @@ -0,0 +1,434 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.compute.operator.lookup; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.compute.data.BasicBlockTests; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BlockTestUtils; +import org.elasticsearch.compute.data.BytesRefVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.ComputeTestCase; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.test.ListMatcher; + +import java.text.NumberFormat; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.Set; +import java.util.stream.IntStream; + +import static org.elasticsearch.test.ListMatcher.matchesList; +import static org.elasticsearch.test.MapMatcher.assertMap; +import static org.hamcrest.Matchers.equalTo; + +public class RightChunkedLeftJoinTests extends ComputeTestCase { + public void testNoGaps() { + testNoGaps(blockFactory()); + } + + public void testNoGapsCranky() { + testWithCrankyBlockFactory(this::testNoGaps); + } + + private void testNoGaps(BlockFactory factory) { + int size = 100; + try (RightChunkedLeftJoin join = new RightChunkedLeftJoin(buildExampleLeftHand(factory, size), 2)) { + assertJoined( + factory, + join, + new int[][] { + { 0, 1, 2 }, // formatter + { 1, 2, 3 }, // formatter + { 2, 3, 3 }, // formatter + { 3, 9, 9 }, // formatter + }, + new Object[][] { + { "l00", 1, 2 }, // formatter + { "l01", 2, 3 }, // formatter + { "l02", 3, 3 }, // formatter + { "l03", 9, 9 }, // formatter + } + ); + assertTrailing(join, size, 4); + } + } + + /** + * Test the first example in the main javadoc of {@link RightChunkedLeftJoin}. + */ + public void testFirstExample() { + testFirstExample(blockFactory()); + } + + public void testFirstExampleCranky() { + testWithCrankyBlockFactory(this::testFirstExample); + } + + private void testFirstExample(BlockFactory factory) { + try (RightChunkedLeftJoin join = new RightChunkedLeftJoin(buildExampleLeftHand(factory, 100), 2)) { + assertJoined( + factory, + join, + new int[][] { + { 0, 1, 2 }, // formatter + { 1, 2, 3 }, // formatter + { 1, 3, 3 }, // formatter + { 3, 9, 9 }, // formatter + }, + new Object[][] { + { "l00", 1, 2 }, // formatter + { "l01", 2, 3 }, // formatter + { "l01", 3, 3 }, // formatter + { "l02", null, null }, // formatter + { "l03", 9, 9 }, // formatter + } + ); + } + } + + public void testLeadingNulls() { + testLeadingNulls(blockFactory()); + } + + public void testLeadingNullsCranky() { + testWithCrankyBlockFactory(this::testLeadingNulls); + } + + private void testLeadingNulls(BlockFactory factory) { + int size = 3; + try (RightChunkedLeftJoin join = new RightChunkedLeftJoin(buildExampleLeftHand(factory, size), 2)) { + assertJoined( + factory, + join, + new int[][] { { 2, 1, 2 } }, + new Object[][] { + { "l0", null, null }, // formatter + { "l1", null, null }, // formatter + { "l2", 1, 2 }, // formatter + } + ); + assertTrailing(join, size, 3); + } + } + + public void testSecondExample() { + testSecondExample(blockFactory()); + } + + public void testSecondExampleCranky() { + testWithCrankyBlockFactory(this::testSecondExample); + } + + /** + * Test the second example in the main javadoc of {@link RightChunkedLeftJoin}. 
+ */ + private void testSecondExample(BlockFactory factory) { + int size = 100; + try (RightChunkedLeftJoin join = new RightChunkedLeftJoin(buildExampleLeftHand(factory, size), 2)) { + assertJoined( + factory, + join, + new int[][] { + { 0, 1, 2 }, // formatter + { 1, 3, 3 }, // formatter + }, + new Object[][] { + { "l00", 1, 2 }, // formatter + { "l01", 3, 3 }, // formatter + } + ); + assertJoined( + factory, + join, + new int[][] { + { 1, 9, 9 }, // formatter + { 2, 9, 9 }, // formatter + }, + new Object[][] { + { "l01", 9, 9 }, // formatter + { "l02", 9, 9 }, // formatter + } + ); + assertJoined( + factory, + join, + new int[][] { + { 5, 10, 10 }, // formatter + { 7, 11, 11 }, // formatter + }, + new Object[][] { + { "l03", null, null }, // formatter + { "l04", null, null }, // formatter + { "l05", 10, 10 }, // formatter + { "l06", null, null }, // formatter + { "l07", 11, 11 }, // formatter + } + ); + assertTrailing(join, size, 8); + } + } + + public void testThirdExample() { + testThirdExample(blockFactory()); + } + + public void testThirdExampleCranky() { + testWithCrankyBlockFactory(this::testThirdExample); + } + + /** + * Test the third example in the main javadoc of {@link RightChunkedLeftJoin}. + */ + private void testThirdExample(BlockFactory factory) { + int size = 100; + try (RightChunkedLeftJoin join = new RightChunkedLeftJoin(buildExampleLeftHand(factory, size), 2)) { + Page pre = buildPage(factory, IntStream.range(0, 96).mapToObj(p -> new int[] { p, p, p }).toArray(int[][]::new)); + try { + join.join(pre).releaseBlocks(); + } finally { + pre.releaseBlocks(); + } + assertJoined( + factory, + join, + new int[][] { + { 96, 1, 2 }, // formatter + { 97, 3, 3 }, // formatter + }, + new Object[][] { + { "l96", 1, 2 }, // formatter + { "l97", 3, 3 }, // formatter + } + ); + assertTrailing(join, size, 98); + } + } + + public void testRandom() { + testRandom(blockFactory()); + } + + public void testRandomCranky() { + testWithCrankyBlockFactory(this::testRandom); + } + + private void testRandom(BlockFactory factory) { + int leftSize = between(100, 10000); + ElementType[] leftColumns = randomArray(1, 10, ElementType[]::new, BasicBlockTests::randomElementType); + ElementType[] rightColumns = randomArray(1, 10, ElementType[]::new, BasicBlockTests::randomElementType); + + RandomPage left = randomPage(factory, leftColumns, leftSize); + try (RightChunkedLeftJoin join = new RightChunkedLeftJoin(left.page, rightColumns.length)) { + int rightSize = 5; + IntVector selected = randomPositions(factory, leftSize, rightSize); + RandomPage right = randomPage(factory, rightColumns, rightSize, selected.asBlock()); + + try { + Page joined = join.join(right.page); + try { + assertThat(joined.getPositionCount(), equalTo(selected.max() + 1)); + + List> actualColumns = new ArrayList<>(); + BlockTestUtils.readInto(actualColumns, joined); + int rightRow = 0; + for (int leftRow = 0; leftRow < joined.getPositionCount(); leftRow++) { + List actualRow = new ArrayList<>(); + for (List actualColumn : actualColumns) { + actualRow.add(actualColumn.get(leftRow)); + } + ListMatcher matcher = ListMatcher.matchesList(); + for (int c = 0; c < leftColumns.length; c++) { + matcher = matcher.item(unwrapSingletonLists(left.blocks[c].values().get(leftRow))); + } + if (selected.getInt(rightRow) == leftRow) { + for (int c = 0; c < rightColumns.length; c++) { + matcher = matcher.item(unwrapSingletonLists(right.blocks[c].values().get(rightRow))); + } + rightRow++; + } else { + for (int c = 0; c < rightColumns.length; c++) { + 
matcher = matcher.item(null); + } + } + assertMap(actualRow, matcher); + } + } finally { + joined.releaseBlocks(); + } + + int start = selected.max() + 1; + if (start >= left.page.getPositionCount()) { + assertThat(join.noMoreRightHandPages().isPresent(), equalTo(false)); + return; + } + Page remaining = join.noMoreRightHandPages().get(); + try { + assertThat(remaining.getPositionCount(), equalTo(left.page.getPositionCount() - start)); + List> actualColumns = new ArrayList<>(); + BlockTestUtils.readInto(actualColumns, remaining); + for (int leftRow = start; leftRow < left.page.getPositionCount(); leftRow++) { + List actualRow = new ArrayList<>(); + for (List actualColumn : actualColumns) { + actualRow.add(actualColumn.get(leftRow - start)); + } + ListMatcher matcher = ListMatcher.matchesList(); + for (int c = 0; c < leftColumns.length; c++) { + matcher = matcher.item(unwrapSingletonLists(left.blocks[c].values().get(leftRow))); + } + for (int c = 0; c < rightColumns.length; c++) { + matcher = matcher.item(null); + } + assertMap(actualRow, matcher); + } + + } finally { + remaining.releaseBlocks(); + } + } finally { + right.page.releaseBlocks(); + } + } finally { + left.page.releaseBlocks(); + } + } + + NumberFormat exampleNumberFormat(int size) { + NumberFormat nf = NumberFormat.getIntegerInstance(Locale.ROOT); + nf.setMinimumIntegerDigits((int) Math.ceil(Math.log10(size))); + return nf; + } + + Page buildExampleLeftHand(BlockFactory factory, int size) { + NumberFormat nf = exampleNumberFormat(size); + try (BytesRefVector.Builder builder = factory.newBytesRefVectorBuilder(size)) { + for (int i = 0; i < size; i++) { + builder.appendBytesRef(new BytesRef("l" + nf.format(i))); + } + return new Page(builder.build().asBlock()); + } + } + + Page buildPage(BlockFactory factory, int[][] rows) { + try ( + IntVector.Builder positions = factory.newIntVectorFixedBuilder(rows.length); + IntVector.Builder r1 = factory.newIntVectorFixedBuilder(rows.length); + IntVector.Builder r2 = factory.newIntVectorFixedBuilder(rows.length); + ) { + for (int[] row : rows) { + positions.appendInt(row[0]); + r1.appendInt(row[1]); + r2.appendInt(row[2]); + } + return new Page(positions.build().asBlock(), r1.build().asBlock(), r2.build().asBlock()); + } + } + + private void assertJoined(Page joined, Object[][] expected) { + try { + List> actualColumns = new ArrayList<>(); + BlockTestUtils.readInto(actualColumns, joined); + + for (int r = 0; r < expected.length; r++) { + List actualRow = new ArrayList<>(); + for (int c = 0; c < actualColumns.size(); c++) { + Object v = actualColumns.get(c).get(r); + if (v instanceof BytesRef b) { + v = b.utf8ToString(); + } + actualRow.add(v); + } + + ListMatcher rowMatcher = matchesList(); + for (Object v : expected[r]) { + rowMatcher = rowMatcher.item(v); + } + assertMap("row " + r, actualRow, rowMatcher); + } + } finally { + joined.releaseBlocks(); + } + } + + private void assertJoined(BlockFactory factory, RightChunkedLeftJoin join, int[][] rightRows, Object[][] expectRows) { + Page rightHand = buildPage(factory, rightRows); + try { + assertJoined(join.join(rightHand), expectRows); + } finally { + rightHand.releaseBlocks(); + } + } + + private void assertTrailing(RightChunkedLeftJoin join, int size, int next) { + NumberFormat nf = exampleNumberFormat(size); + if (size == next) { + assertThat(join.noMoreRightHandPages(), equalTo(Optional.empty())); + } else { + assertJoined( + join.noMoreRightHandPages().get(), + IntStream.range(next, size).mapToObj(p -> new Object[] { "l" + nf.format(p), 
null, null }).toArray(Object[][]::new) + ); + } + } + + Object unwrapSingletonLists(Object o) { + if (o instanceof List l && l.size() == 1) { + return l.getFirst(); + } + return o; + } + + record RandomPage(Page page, BasicBlockTests.RandomBlock[] blocks) {}; + + RandomPage randomPage(BlockFactory factory, ElementType[] types, int positions, Block... prepend) { + BasicBlockTests.RandomBlock[] randomBlocks = new BasicBlockTests.RandomBlock[types.length]; + Block[] blocks = new Block[prepend.length + types.length]; + try { + for (int c = 0; c < prepend.length; c++) { + blocks[c] = prepend[c]; + } + for (int c = 0; c < types.length; c++) { + + int min = between(0, 3); + randomBlocks[c] = BasicBlockTests.randomBlock( + factory, + types[c], + positions, + randomBoolean(), + min, + between(min, min + 3), + 0, + 0 + ); + blocks[prepend.length + c] = randomBlocks[c].block(); + } + Page p = new Page(blocks); + blocks = null; + return new RandomPage(p, randomBlocks); + } finally { + if (blocks != null) { + Releasables.close(blocks); + } + } + } + + IntVector randomPositions(BlockFactory factory, int leftSize, int positionCount) { + Set positions = new HashSet<>(); + while (positions.size() < positionCount) { + positions.add(between(0, leftSize - 1)); + } + int[] positionsArray = positions.stream().mapToInt(i -> i).sorted().toArray(); + return factory.newIntArrayVector(positionsArray, positionsArray.length); + } +} diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java index 3b9359fe66d40..f31eabea9d616 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/LookupFromIndexIT.java @@ -69,13 +69,56 @@ import static org.hamcrest.Matchers.empty; public class LookupFromIndexIT extends AbstractEsqlIntegTestCase { + // TODO should we remove this now that this is integrated into ESQL proper? /** * Quick and dirty test for looking up data from a lookup index. */ public void testLookupIndex() throws IOException { - // TODO this should *fail* if the target index isn't a lookup type index - it doesn't now. - int docCount = between(10, 1000); - List expected = new ArrayList<>(docCount); + runLookup(new UsingSingleLookupTable(new String[] { "aa", "bb", "cc", "dd" })); + } + + /** + * Tests when multiple results match. 
+ */ + @AwaitsFix(bugUrl = "fixing real soon now") + public void testLookupIndexMultiResults() throws IOException { + runLookup(new UsingSingleLookupTable(new Object[] { "aa", new String[] { "bb", "ff" }, "cc", "dd" })); + } + + interface PopulateIndices { + void populate(int docCount, List expected) throws IOException; + } + + class UsingSingleLookupTable implements PopulateIndices { + private final Object[] lookupData; + + UsingSingleLookupTable(Object[] lookupData) { + this.lookupData = lookupData; + } + + @Override + public void populate(int docCount, List expected) throws IOException { + List docs = new ArrayList<>(); + for (int i = 0; i < docCount; i++) { + docs.add(client().prepareIndex("source").setSource(Map.of("data", lookupData[i % lookupData.length]))); + Object d = lookupData[i % lookupData.length]; + if (d instanceof String s) { + expected.add(s + ":" + (i % lookupData.length)); + } else if (d instanceof String[] ss) { + for (String s : ss) { + expected.add(s + ":" + (i % lookupData.length)); + } + } + } + for (int i = 0; i < lookupData.length; i++) { + docs.add(client().prepareIndex("lookup").setSource(Map.of("data", lookupData[i], "l", i))); + } + Collections.sort(expected); + indexRandom(true, true, docs); + } + } + + private void runLookup(PopulateIndices populateIndices) throws IOException { client().admin() .indices() .prepareCreate("source") @@ -95,17 +138,9 @@ public void testLookupIndex() throws IOException { .get(); client().admin().cluster().prepareHealth(TEST_REQUEST_TIMEOUT).setWaitForGreenStatus().get(); - String[] data = new String[] { "aa", "bb", "cc", "dd" }; - List docs = new ArrayList<>(); - for (int i = 0; i < docCount; i++) { - docs.add(client().prepareIndex("source").setSource(Map.of("data", data[i % data.length]))); - expected.add(data[i % data.length] + ":" + (i % data.length)); - } - for (int i = 0; i < data.length; i++) { - docs.add(client().prepareIndex("lookup").setSource(Map.of("data", data[i], "l", i))); - } - Collections.sort(expected); - indexRandom(true, true, docs); + int docCount = between(10, 1000); + List expected = new ArrayList<>(docCount); + populateIndices.populate(docCount, expected); /* * Find the data node hosting the only shard of the source index.