diff --git a/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java b/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java
index 444dbcc1b9e58..8096153459003 100644
--- a/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java
+++ b/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java
@@ -37,11 +37,6 @@
* are ever collected.
*
*
- * The generation code will also look for a method called {@code combineValueCount}
- * which is called once per received block with a count of values. NOTE: We may
- * not need this after we convert AVG into a composite operation.
- *
- *
* The generation code also looks for the optional methods {@code combineIntermediate}
* and {@code evaluateFinal} which are used to combine intermediate states and
* produce the final output. If the first is missing then the generated code will
@@ -63,4 +58,8 @@
*/
Class extends Exception>[] warnExceptions() default {};
+ /**
+ * If {@code true} then the @timestamp LongVector will be appended to the input blocks of the aggregation function.
+ */
+ boolean includeTimestamps() default false;
}
diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java
index 46881bf337c89..d775a46109214 100644
--- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java
+++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java
@@ -17,12 +17,12 @@
import org.elasticsearch.compute.ann.Aggregator;
import org.elasticsearch.compute.ann.IntermediateState;
+import org.elasticsearch.compute.gen.Methods.TypeMatcher;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
-import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -34,27 +34,24 @@
import javax.lang.model.util.Elements;
import static java.util.stream.Collectors.joining;
-import static org.elasticsearch.compute.gen.Methods.findMethod;
-import static org.elasticsearch.compute.gen.Methods.findRequiredMethod;
+import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
+import static org.elasticsearch.compute.gen.Methods.requireAnyType;
+import static org.elasticsearch.compute.gen.Methods.requireArgs;
+import static org.elasticsearch.compute.gen.Methods.requireName;
+import static org.elasticsearch.compute.gen.Methods.requirePrimitiveOrImplements;
+import static org.elasticsearch.compute.gen.Methods.requireStaticMethod;
+import static org.elasticsearch.compute.gen.Methods.requireType;
+import static org.elasticsearch.compute.gen.Methods.requireVoidType;
import static org.elasticsearch.compute.gen.Methods.vectorAccessorName;
import static org.elasticsearch.compute.gen.Types.AGGREGATOR_FUNCTION;
import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS;
import static org.elasticsearch.compute.gen.Types.BLOCK;
import static org.elasticsearch.compute.gen.Types.BLOCK_ARRAY;
-import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK;
import static org.elasticsearch.compute.gen.Types.BOOLEAN_VECTOR;
import static org.elasticsearch.compute.gen.Types.BYTES_REF;
-import static org.elasticsearch.compute.gen.Types.BYTES_REF_BLOCK;
-import static org.elasticsearch.compute.gen.Types.BYTES_REF_VECTOR;
-import static org.elasticsearch.compute.gen.Types.DOUBLE_BLOCK;
-import static org.elasticsearch.compute.gen.Types.DOUBLE_VECTOR;
import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT;
import static org.elasticsearch.compute.gen.Types.ELEMENT_TYPE;
-import static org.elasticsearch.compute.gen.Types.FLOAT_BLOCK;
-import static org.elasticsearch.compute.gen.Types.FLOAT_VECTOR;
import static org.elasticsearch.compute.gen.Types.INTERMEDIATE_STATE_DESC;
-import static org.elasticsearch.compute.gen.Types.INT_BLOCK;
-import static org.elasticsearch.compute.gen.Types.INT_VECTOR;
import static org.elasticsearch.compute.gen.Types.LIST_AGG_FUNC_DESC;
import static org.elasticsearch.compute.gen.Types.LIST_INTEGER;
import static org.elasticsearch.compute.gen.Types.LONG_BLOCK;
@@ -78,46 +75,41 @@ public class AggregatorImplementer {
private final List warnExceptions;
private final ExecutableElement init;
private final ExecutableElement combine;
- private final ExecutableElement combineValueCount;
- private final ExecutableElement combineIntermediate;
- private final ExecutableElement evaluateFinal;
+ private final List createParameters;
private final ClassName implementation;
- private final TypeName stateType;
- private final boolean stateTypeHasSeen;
- private final boolean stateTypeHasFailed;
- private final boolean valuesIsBytesRef;
- private final boolean valuesIsArray;
private final List intermediateState;
- private final List createParameters;
+ private final boolean includeTimestampVector;
+
+ private final AggregationState aggState;
+ private final AggregationParameter aggParam;
public AggregatorImplementer(
Elements elements,
TypeElement declarationType,
IntermediateState[] interStateAnno,
- List warnExceptions
+ List warnExceptions,
+ boolean includeTimestampVector
) {
this.declarationType = declarationType;
this.warnExceptions = warnExceptions;
- this.init = findRequiredMethod(declarationType, new String[] { "init", "initSingle" }, e -> true);
- this.stateType = choseStateType();
- this.stateTypeHasSeen = elements.getAllMembers(elements.getTypeElement(stateType.toString()))
- .stream()
- .anyMatch(e -> e.toString().equals("seen()"));
- this.stateTypeHasFailed = elements.getAllMembers(elements.getTypeElement(stateType.toString()))
- .stream()
- .anyMatch(e -> e.toString().equals("failed()"));
+ this.init = requireStaticMethod(
+ declarationType,
+ requirePrimitiveOrImplements(elements, Types.AGGREGATOR_STATE),
+ requireName("init", "initSingle"),
+ requireAnyArgs("")
+ );
+ this.aggState = AggregationState.create(elements, init.getReturnType(), warnExceptions.isEmpty() == false, false);
+
+ this.combine = requireStaticMethod(
+ declarationType,
+ aggState.declaredType().isPrimitive() ? requireType(aggState.declaredType()) : requireVoidType(),
+ requireName("combine"),
+ combineArgs(aggState, includeTimestampVector)
+ );
+ // TODO support multiple parameters
+ this.aggParam = AggregationParameter.create(combine.getParameters().getLast().asType());
- this.combine = findRequiredMethod(declarationType, new String[] { "combine" }, e -> {
- if (e.getParameters().size() == 0) {
- return false;
- }
- TypeName firstParamType = TypeName.get(e.getParameters().get(0).asType());
- return firstParamType.isPrimitive() || firstParamType.toString().equals(stateType.toString());
- });
- this.combineValueCount = findMethod(declarationType, "combineValueCount");
- this.combineIntermediate = findMethod(declarationType, "combineIntermediate");
- this.evaluateFinal = findMethod(declarationType, "evaluateFinal");
this.createParameters = init.getParameters()
.stream()
.map(Parameter::from)
@@ -128,9 +120,20 @@ public AggregatorImplementer(
elements.getPackageOf(declarationType).toString(),
(declarationType.getSimpleName() + "AggregatorFunction").replace("AggregatorAggregator", "Aggregator")
);
- this.valuesIsBytesRef = BYTES_REF.equals(valueTypeName());
- this.valuesIsArray = TypeKind.ARRAY.equals(valueTypeKind());
- intermediateState = Arrays.stream(interStateAnno).map(IntermediateStateDesc::newIntermediateStateDesc).toList();
+ this.intermediateState = Arrays.stream(interStateAnno).map(IntermediateStateDesc::newIntermediateStateDesc).toList();
+ this.includeTimestampVector = includeTimestampVector;
+ }
+
+ private static Methods.ArgumentMatcher combineArgs(AggregationState aggState, boolean includeTimestampVector) {
+ if (includeTimestampVector) {
+ return requireArgs(
+ requireType(aggState.declaredType()),
+ requireType(TypeName.LONG), // @timestamp
+ requireAnyType("")
+ );
+ } else {
+ return requireArgs(requireType(aggState.declaredType()), requireAnyType(""));
+ }
}
ClassName implementation() {
@@ -141,68 +144,8 @@ List createParameters() {
return createParameters;
}
- private TypeName choseStateType() {
- TypeName initReturn = TypeName.get(init.getReturnType());
- if (false == initReturn.isPrimitive()) {
- return initReturn;
- }
- String simpleName = firstUpper(initReturn.toString());
- if (warnExceptions.isEmpty()) {
- return ClassName.get("org.elasticsearch.compute.aggregation", simpleName + "State");
- }
- return ClassName.get("org.elasticsearch.compute.aggregation", simpleName + "FallibleState");
- }
-
- static String valueType(ExecutableElement init, ExecutableElement combine) {
- if (combine != null) {
- // If there's an explicit combine function it's final parameter is the type of the value.
- return combine.getParameters().get(combine.getParameters().size() - 1).asType().toString();
- }
- String initReturn = init.getReturnType().toString();
- switch (initReturn) {
- case "double":
- return "double";
- case "float":
- return "float";
- case "long":
- return "long";
- case "int":
- return "int";
- case "boolean":
- return "boolean";
- default:
- throw new IllegalArgumentException("unknown primitive type for " + initReturn);
- }
- }
-
- static ClassName valueBlockType(ExecutableElement init, ExecutableElement combine) {
- return switch (valueType(init, combine)) {
- case "boolean" -> BOOLEAN_BLOCK;
- case "double" -> DOUBLE_BLOCK;
- case "float" -> FLOAT_BLOCK;
- case "long" -> LONG_BLOCK;
- case "int", "int[]" -> INT_BLOCK;
- case "org.apache.lucene.util.BytesRef" -> BYTES_REF_BLOCK;
- default -> throw new IllegalArgumentException("unknown block type for " + valueType(init, combine));
- };
- }
-
- static ClassName valueVectorType(ExecutableElement init, ExecutableElement combine) {
- return switch (valueType(init, combine)) {
- case "boolean" -> BOOLEAN_VECTOR;
- case "double" -> DOUBLE_VECTOR;
- case "float" -> FLOAT_VECTOR;
- case "long" -> LONG_VECTOR;
- case "int", "int[]" -> INT_VECTOR;
- case "org.apache.lucene.util.BytesRef" -> BYTES_REF_VECTOR;
- default -> throw new IllegalArgumentException("unknown vector type for " + valueType(init, combine));
- };
- }
-
- public static String firstUpper(String s) {
- String head = s.toString().substring(0, 1).toUpperCase(Locale.ROOT);
- String tail = s.toString().substring(1);
- return head + tail;
+ public static String capitalize(String s) {
+ return Character.toUpperCase(s.charAt(0)) + s.substring(1);
}
public JavaFile sourceFile() {
@@ -232,7 +175,7 @@ private TypeSpec type() {
}
builder.addField(DRIVER_CONTEXT, "driverContext", Modifier.PRIVATE, Modifier.FINAL);
- builder.addField(stateType, "state", Modifier.PRIVATE, Modifier.FINAL);
+ builder.addField(aggState.type, "state", Modifier.PRIVATE, Modifier.FINAL);
builder.addField(LIST_INTEGER, "channels", Modifier.PRIVATE, Modifier.FINAL);
for (Parameter p : createParameters) {
@@ -292,10 +235,10 @@ private CodeBlock callInit() {
.map(p -> TypeName.get(p.asType()).equals(BIG_ARRAYS) ? "driverContext.bigArrays()" : p.getSimpleName().toString())
.collect(joining(", "));
CodeBlock.Builder builder = CodeBlock.builder();
- if (init.getReturnType().toString().equals(stateType.toString())) {
- builder.add("$T.$L($L)", declarationType, init.getSimpleName(), initParametersCall);
+ if (aggState.declaredType().isPrimitive()) {
+ builder.add("new $T($T.$L($L))", aggState.type(), declarationType, init.getSimpleName(), initParametersCall);
} else {
- builder.add("new $T($T.$L($L))", stateType, declarationType, init.getSimpleName(), initParametersCall);
+ builder.add("$T.$L($L)", declarationType, init.getSimpleName(), initParametersCall);
}
return builder.build();
}
@@ -320,7 +263,7 @@ private MethodSpec ctor() {
}
builder.addParameter(DRIVER_CONTEXT, "driverContext");
builder.addParameter(LIST_INTEGER, "channels");
- builder.addParameter(stateType, "state");
+ builder.addParameter(aggState.type, "state");
if (warnExceptions.isEmpty() == false) {
builder.addStatement("this.warnings = warnings");
@@ -352,7 +295,7 @@ private MethodSpec intermediateBlockCount() {
private MethodSpec addRawInput() {
MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawInput");
builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC).addParameter(PAGE, "page").addParameter(BOOLEAN_VECTOR, "mask");
- if (stateTypeHasFailed) {
+ if (aggState.hasFailed()) {
builder.beginControlFlow("if (state.failed())");
builder.addStatement("return");
builder.endControlFlow();
@@ -366,43 +309,62 @@ private MethodSpec addRawInput() {
builder.beginControlFlow("if (mask.allTrue())");
{
builder.addComment("No masking");
- builder.addStatement("$T block = page.getBlock(channels.get(0))", valueBlockType(init, combine));
- builder.addStatement("$T vector = block.asVector()", valueVectorType(init, combine));
+ builder.addStatement("$T block = page.getBlock(channels.get(0))", blockType(aggParam.type()));
+ builder.addStatement("$T vector = block.asVector()", vectorType(aggParam.type()));
+ if (includeTimestampVector) {
+ builder.addStatement("$T timestampsBlock = page.getBlock(channels.get(1))", LONG_BLOCK);
+ builder.addStatement("$T timestampsVector = timestampsBlock.asVector()", LONG_VECTOR);
+
+ builder.beginControlFlow("if (timestampsVector == null) ");
+ builder.addStatement("throw new IllegalStateException($S)", "expected @timestamp vector; but got a block");
+ builder.endControlFlow();
+ }
builder.beginControlFlow("if (vector != null)");
- builder.addStatement("addRawVector(vector)");
+ builder.addStatement(includeTimestampVector ? "addRawVector(vector, timestampsVector)" : "addRawVector(vector)");
builder.nextControlFlow("else");
- builder.addStatement("addRawBlock(block)");
+ builder.addStatement(includeTimestampVector ? "addRawBlock(block, timestampsVector)" : "addRawBlock(block)");
builder.endControlFlow();
builder.addStatement("return");
}
builder.endControlFlow();
builder.addComment("Some positions masked away, others kept");
- builder.addStatement("$T block = page.getBlock(channels.get(0))", valueBlockType(init, combine));
- builder.addStatement("$T vector = block.asVector()", valueVectorType(init, combine));
+ builder.addStatement("$T block = page.getBlock(channels.get(0))", blockType(aggParam.type()));
+ builder.addStatement("$T vector = block.asVector()", vectorType(aggParam.type()));
+ if (includeTimestampVector) {
+ builder.addStatement("$T timestampsBlock = page.getBlock(channels.get(1))", LONG_BLOCK);
+ builder.addStatement("$T timestampsVector = timestampsBlock.asVector()", LONG_VECTOR);
+
+ builder.beginControlFlow("if (timestampsVector == null) ");
+ builder.addStatement("throw new IllegalStateException($S)", "expected @timestamp vector; but got a block");
+ builder.endControlFlow();
+ }
builder.beginControlFlow("if (vector != null)");
- builder.addStatement("addRawVector(vector, mask)");
+ builder.addStatement(includeTimestampVector ? "addRawVector(vector, timestampsVector, mask)" : "addRawVector(vector, mask)");
builder.nextControlFlow("else");
- builder.addStatement("addRawBlock(block, mask)");
+ builder.addStatement(includeTimestampVector ? "addRawBlock(block, timestampsVector, mask)" : "addRawBlock(block, mask)");
builder.endControlFlow();
return builder.build();
}
private MethodSpec addRawVector(boolean masked) {
MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawVector");
- builder.addModifiers(Modifier.PRIVATE).addParameter(valueVectorType(init, combine), "vector");
+ builder.addModifiers(Modifier.PRIVATE).addParameter(vectorType(aggParam.type()), "vector");
+ if (includeTimestampVector) {
+ builder.addParameter(LONG_VECTOR, "timestamps");
+ }
if (masked) {
builder.addParameter(BOOLEAN_VECTOR, "mask");
}
- if (valuesIsArray) {
+ if (aggParam.isArray()) {
builder.addComment("This type does not support vectors because all values are multi-valued");
return builder.build();
}
- if (stateTypeHasSeen) {
+ if (aggState.hasSeen()) {
builder.addStatement("state.seen(true)");
}
- if (valuesIsBytesRef) {
+ if (aggParam.isBytesRef()) {
// Add bytes_ref scratch var that will be used for bytes_ref blocks/vectors
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
}
@@ -415,20 +377,20 @@ private MethodSpec addRawVector(boolean masked) {
combineRawInput(builder, "vector");
}
builder.endControlFlow();
- if (combineValueCount != null) {
- builder.addStatement("$T.combineValueCount(state, vector.getPositionCount())", declarationType);
- }
return builder.build();
}
private MethodSpec addRawBlock(boolean masked) {
MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawBlock");
- builder.addModifiers(Modifier.PRIVATE).addParameter(valueBlockType(init, combine), "block");
+ builder.addModifiers(Modifier.PRIVATE).addParameter(blockType(aggParam.type()), "block");
+ if (includeTimestampVector) {
+ builder.addParameter(LONG_VECTOR, "timestamps");
+ }
if (masked) {
builder.addParameter(BOOLEAN_VECTOR, "mask");
}
- if (valuesIsBytesRef) {
+ if (aggParam.isBytesRef()) {
// Add bytes_ref scratch var that will only be used for bytes_ref blocks/vectors
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
}
@@ -440,16 +402,16 @@ private MethodSpec addRawBlock(boolean masked) {
builder.beginControlFlow("if (block.isNull(p))");
builder.addStatement("continue");
builder.endControlFlow();
- if (stateTypeHasSeen) {
+ if (aggState.hasSeen()) {
builder.addStatement("state.seen(true)");
}
builder.addStatement("int start = block.getFirstValueIndex(p)");
builder.addStatement("int end = start + block.getValueCount(p)");
- if (valuesIsArray) {
- String arrayType = valueTypeString();
+ if (aggParam.isArray()) {
+ String arrayType = aggParam.type().toString().replace("[]", "");
builder.addStatement("$L[] valuesArray = new $L[end - start]", arrayType, arrayType);
builder.beginControlFlow("for (int i = start; i < end; i++)");
- builder.addStatement("valuesArray[i-start] = $L.get$L(i)", "block", firstUpper(arrayType));
+ builder.addStatement("valuesArray[i-start] = $L.get$L(i)", "block", capitalize(arrayType));
builder.endControlFlow();
combineRawInputForArray(builder, "valuesArray");
} else {
@@ -459,16 +421,13 @@ private MethodSpec addRawBlock(boolean masked) {
}
}
builder.endControlFlow();
- if (combineValueCount != null) {
- builder.addStatement("$T.combineValueCount(state, block.getTotalValueCount())", declarationType);
- }
return builder.build();
}
private void combineRawInput(MethodSpec.Builder builder, String blockVariable) {
TypeName returnType = TypeName.get(combine.getReturnType());
warningsBlock(builder, () -> {
- if (valuesIsBytesRef) {
+ if (aggParam.isBytesRef()) {
combineRawInputForBytesRef(builder, blockVariable);
} else if (returnType.isPrimitive()) {
combineRawInputForPrimitive(returnType, builder, blockVariable);
@@ -480,33 +439,57 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable) {
});
}
- private void combineRawInputForPrimitive(TypeName returnType, MethodSpec.Builder builder, String blockVariable) {
- builder.addStatement(
- "state.$TValue($T.combine(state.$TValue(), $L.get$L(i)))",
- returnType,
- declarationType,
- returnType,
- blockVariable,
- firstUpper(combine.getParameters().get(1).asType().toString())
- );
+ private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable) {
+ // scratch is a BytesRef var that must have been defined before the iteration starts
+ if (includeTimestampVector) {
+ builder.addStatement("$T.combine(state, timestamps.getLong(i), $L.getBytesRef(i, scratch))", declarationType, blockVariable);
+ } else {
+ builder.addStatement("$T.combine(state, $L.getBytesRef(i, scratch))", declarationType, blockVariable);
+ }
}
- private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) {
- warningsBlock(builder, () -> builder.addStatement("$T.combine(state, $L)", declarationType, arrayVariable));
+ private void combineRawInputForPrimitive(TypeName returnType, MethodSpec.Builder builder, String blockVariable) {
+ if (includeTimestampVector) {
+ builder.addStatement(
+ "state.$TValue($T.combine(state.$TValue(), timestamps.getLong(i), $L.get$L(i)))",
+ returnType,
+ declarationType,
+ returnType,
+ blockVariable,
+ capitalize(combine.getParameters().get(1).asType().toString())
+ );
+ } else {
+ builder.addStatement(
+ "state.$TValue($T.combine(state.$TValue(), $L.get$L(i)))",
+ returnType,
+ declarationType,
+ returnType,
+ blockVariable,
+ capitalize(combine.getParameters().get(1).asType().toString())
+ );
+ }
}
private void combineRawInputForVoid(MethodSpec.Builder builder, String blockVariable) {
- builder.addStatement(
- "$T.combine(state, $L.get$L(i))",
- declarationType,
- blockVariable,
- firstUpper(combine.getParameters().get(1).asType().toString())
- );
+ if (includeTimestampVector) {
+ builder.addStatement(
+ "$T.combine(state, timestamps.getLong(i), $L.get$L(i))",
+ declarationType,
+ blockVariable,
+ capitalize(combine.getParameters().get(1).asType().toString())
+ );
+ } else {
+ builder.addStatement(
+ "$T.combine(state, $L.get$L(i))",
+ declarationType,
+ blockVariable,
+ capitalize(combine.getParameters().get(1).asType().toString())
+ );
+ }
}
- private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable) {
- // scratch is a BytesRef var that must have been defined before the iteration starts
- builder.addStatement("$T.combine(state, $L.getBytesRef(i, scratch))", declarationType, blockVariable);
+ private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) {
+ warningsBlock(builder, () -> builder.addStatement("$T.combine(state, $L)", declarationType, arrayVariable));
}
private void warningsBlock(MethodSpec.Builder builder, Runnable block) {
@@ -534,12 +517,7 @@ private MethodSpec addIntermediateInput() {
interState.assignToVariable(builder, i);
builder.addStatement("assert $L.getPositionCount() == 1", interState.name());
}
- if (combineIntermediate != null) {
- if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
- builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
- }
- builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType);
- } else if (hasPrimitiveState()) {
+ if (aggState.declaredType().isPrimitive()) {
if (warnExceptions.isEmpty()) {
assert intermediateState.size() == 2;
assert intermediateState.get(1).name().equals("seen");
@@ -557,14 +535,36 @@ private MethodSpec addIntermediateInput() {
}
warningsBlock(builder, () -> {
+ var primitiveStateMethod = switch (aggState.declaredType().toString()) {
+ case "boolean" -> "booleanValue";
+ case "int" -> "intValue";
+ case "long" -> "longValue";
+ case "double" -> "doubleValue";
+ case "float" -> "floatValue";
+ default -> throw new IllegalArgumentException("Unexpected primitive type: [" + aggState.declaredType() + "]");
+ };
var state = intermediateState.get(0);
var s = "state.$L($T.combine(state.$L(), " + state.name() + "." + vectorAccessorName(state.elementType()) + "(0)))";
- builder.addStatement(s, primitiveStateMethod(), declarationType, primitiveStateMethod());
+ builder.addStatement(s, primitiveStateMethod, declarationType, primitiveStateMethod);
builder.addStatement("state.seen(true)");
});
builder.endControlFlow();
} else {
- throw new IllegalArgumentException("Don't know how to combine intermediate input. Define combineIntermediate");
+ requireStaticMethod(
+ declarationType,
+ requireVoidType(),
+ requireName("combineIntermediate"),
+ requireArgs(
+ Stream.concat(
+ Stream.of(aggState.declaredType()), // aggState
+ intermediateState.stream().map(IntermediateStateDesc::combineArgType) // intermediate state
+ ).map(Methods::requireType).toArray(TypeMatcher[]::new)
+ )
+ );
+ if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) {
+ builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
+ }
+ builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType);
}
return builder.build();
}
@@ -573,25 +573,6 @@ String intermediateStateRowAccess() {
return intermediateState.stream().map(desc -> desc.access("0")).collect(joining(", "));
}
- private String primitiveStateMethod() {
- switch (stateType.toString()) {
- case "org.elasticsearch.compute.aggregation.BooleanState", "org.elasticsearch.compute.aggregation.BooleanFallibleState":
- return "booleanValue";
- case "org.elasticsearch.compute.aggregation.IntState", "org.elasticsearch.compute.aggregation.IntFallibleState":
- return "intValue";
- case "org.elasticsearch.compute.aggregation.LongState", "org.elasticsearch.compute.aggregation.LongFallibleState":
- return "longValue";
- case "org.elasticsearch.compute.aggregation.DoubleState", "org.elasticsearch.compute.aggregation.DoubleFallibleState":
- return "doubleValue";
- case "org.elasticsearch.compute.aggregation.FloatState", "org.elasticsearch.compute.aggregation.FloatFallibleState":
- return "floatValue";
- default:
- throw new IllegalArgumentException(
- "don't know how to fetch primitive values from " + stateType + ". define combineIntermediate."
- );
- }
- }
-
private MethodSpec evaluateIntermediate() {
MethodSpec.Builder builder = MethodSpec.methodBuilder("evaluateIntermediate");
builder.addAnnotation(Override.class)
@@ -610,45 +591,39 @@ private MethodSpec evaluateFinal() {
.addParameter(BLOCK_ARRAY, "blocks")
.addParameter(TypeName.INT, "offset")
.addParameter(DRIVER_CONTEXT, "driverContext");
- if (stateTypeHasSeen || stateTypeHasFailed) {
- var condition = Stream.of(stateTypeHasSeen ? "state.seen() == false" : null, stateTypeHasFailed ? "state.failed()" : null)
- .filter(Objects::nonNull)
- .collect(joining(" || "));
- builder.beginControlFlow("if ($L)", condition);
+ if (aggState.hasSeen() || aggState.hasFailed()) {
+ builder.beginControlFlow(
+ "if ($L)",
+ Stream.concat(
+ Stream.of("state.seen() == false").filter(c -> aggState.hasSeen()),
+ Stream.of("state.failed()").filter(c -> aggState.hasFailed())
+ ).collect(joining(" || "))
+ );
builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantNullBlock(1)", BLOCK);
builder.addStatement("return");
builder.endControlFlow();
}
- if (evaluateFinal == null) {
- primitiveStateToResult(builder);
+ if (aggState.declaredType().isPrimitive()) {
+ builder.addStatement(switch (aggState.declaredType().toString()) {
+ case "boolean" -> "blocks[offset] = driverContext.blockFactory().newConstantBooleanBlockWith(state.booleanValue(), 1)";
+ case "int" -> "blocks[offset] = driverContext.blockFactory().newConstantIntBlockWith(state.intValue(), 1)";
+ case "long" -> "blocks[offset] = driverContext.blockFactory().newConstantLongBlockWith(state.longValue(), 1)";
+ case "double" -> "blocks[offset] = driverContext.blockFactory().newConstantDoubleBlockWith(state.doubleValue(), 1)";
+ case "float" -> "blocks[offset] = driverContext.blockFactory().newConstantFloatBlockWith(state.floatValue(), 1)";
+ default -> throw new IllegalArgumentException("Unexpected primitive type: [" + aggState.declaredType() + "]");
+ });
} else {
+ requireStaticMethod(
+ declarationType,
+ requireType(BLOCK),
+ requireName("evaluateFinal"),
+ requireArgs(requireType(aggState.declaredType()), requireType(DRIVER_CONTEXT))
+ );
builder.addStatement("blocks[offset] = $T.evaluateFinal(state, driverContext)", declarationType);
}
return builder.build();
}
- private void primitiveStateToResult(MethodSpec.Builder builder) {
- switch (stateType.toString()) {
- case "org.elasticsearch.compute.aggregation.BooleanState", "org.elasticsearch.compute.aggregation.BooleanFallibleState":
- builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantBooleanBlockWith(state.booleanValue(), 1)");
- return;
- case "org.elasticsearch.compute.aggregation.IntState", "org.elasticsearch.compute.aggregation.IntFallibleState":
- builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantIntBlockWith(state.intValue(), 1)");
- return;
- case "org.elasticsearch.compute.aggregation.LongState", "org.elasticsearch.compute.aggregation.LongFallibleState":
- builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantLongBlockWith(state.longValue(), 1)");
- return;
- case "org.elasticsearch.compute.aggregation.DoubleState", "org.elasticsearch.compute.aggregation.DoubleFallibleState":
- builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantDoubleBlockWith(state.doubleValue(), 1)");
- return;
- case "org.elasticsearch.compute.aggregation.FloatState", "org.elasticsearch.compute.aggregation.FloatFallibleState":
- builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantFloatBlockWith(state.floatValue(), 1)");
- return;
- default:
- throw new IllegalArgumentException("don't know how to convert state to result: " + stateType);
- }
- }
-
private MethodSpec toStringMethod() {
MethodSpec.Builder builder = MethodSpec.methodBuilder("toString");
builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC).returns(String.class);
@@ -667,14 +642,6 @@ private MethodSpec close() {
return builder.build();
}
- private static final Pattern PRIMITIVE_STATE_PATTERN = Pattern.compile(
- "org.elasticsearch.compute.aggregation.(Boolean|Int|Long|Double|Float)(Fallible)?State"
- );
-
- private boolean hasPrimitiveState() {
- return PRIMITIVE_STATE_PATTERN.matcher(stateType.toString()).matches();
- }
-
record IntermediateStateDesc(String name, String elementType, boolean block) {
static IntermediateStateDesc newIntermediateStateDesc(IntermediateState state) {
String type = state.type();
@@ -711,22 +678,57 @@ public void assignToVariable(MethodSpec.Builder builder, int offset) {
builder.addStatement("$T $L = (($T) $L).asVector()", vectorType(elementType), name, blockType, name + "Uncast");
}
}
- }
- private TypeMirror valueTypeMirror() {
- return combine.getParameters().get(combine.getParameters().size() - 1).asType();
+ public TypeName combineArgType() {
+ var type = Types.fromString(elementType);
+ return block ? blockType(type) : type;
+ }
}
- private TypeName valueTypeName() {
- return TypeName.get(valueTypeMirror());
+ /**
+ * This represents the type returned by init method used to keep aggregation state
+ * @param declaredType declared state type as returned by init method
+ * @param type actual type used (we have some predefined state types for primitive values)
+ */
+ public record AggregationState(TypeName declaredType, TypeName type, boolean hasSeen, boolean hasFailed) {
+
+ public static AggregationState create(Elements elements, TypeMirror mirror, boolean hasFailures, boolean isArray) {
+ var declaredType = TypeName.get(mirror);
+ var stateType = declaredType.isPrimitive()
+ ? ClassName.get("org.elasticsearch.compute.aggregation", primitiveStateStoreClassname(declaredType, hasFailures, isArray))
+ : declaredType;
+ return new AggregationState(
+ declaredType,
+ stateType,
+ hasMethod(elements, stateType, "seen()"),
+ hasMethod(elements, stateType, "failed()")
+ );
+ }
+
+ private static String primitiveStateStoreClassname(TypeName declaredType, boolean hasFailures, boolean isArray) {
+ var name = capitalize(declaredType.toString());
+ if (hasFailures) {
+ name += "Fallible";
+ }
+ if (isArray) {
+ name += "Array";
+ }
+ return name + "State";
+ }
}
- private TypeKind valueTypeKind() {
- return valueTypeMirror().getKind();
+ public record AggregationParameter(TypeName type, boolean isArray) {
+
+ public static AggregationParameter create(TypeMirror mirror) {
+ return new AggregationParameter(TypeName.get(mirror), Objects.equals(mirror.getKind(), TypeKind.ARRAY));
+ }
+
+ public boolean isBytesRef() {
+ return Objects.equals(type, BYTES_REF);
+ }
}
- private String valueTypeString() {
- String valueTypeString = TypeName.get(valueTypeMirror()).toString();
- return valuesIsArray ? valueTypeString.substring(0, valueTypeString.length() - 2) : valueTypeString;
+ private static boolean hasMethod(Elements elements, TypeName type, String name) {
+ return elements.getAllMembers(elements.getTypeElement(type.toString())).stream().anyMatch(e -> e.toString().equals(name));
}
}
diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java
index 863db86eb934a..3ad2343ad1658 100644
--- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java
+++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java
@@ -87,7 +87,13 @@ public boolean process(Set extends TypeElement> set, RoundEnvironment roundEnv
);
if (aggClass.getAnnotation(Aggregator.class) != null) {
IntermediateState[] intermediateState = aggClass.getAnnotation(Aggregator.class).value();
- implementer = new AggregatorImplementer(env.getElementUtils(), aggClass, intermediateState, warnExceptionsTypes);
+ implementer = new AggregatorImplementer(
+ env.getElementUtils(),
+ aggClass,
+ intermediateState,
+ warnExceptionsTypes,
+ aggClass.getAnnotation(Aggregator.class).includeTimestamps()
+ );
write(aggClass, "aggregator", implementer.sourceFile(), env);
}
GroupingAggregatorImplementer groupingAggregatorImplementer = null;
@@ -96,13 +102,12 @@ public boolean process(Set extends TypeElement> set, RoundEnvironment roundEnv
if (intermediateState.length == 0 && aggClass.getAnnotation(Aggregator.class) != null) {
intermediateState = aggClass.getAnnotation(Aggregator.class).value();
}
- boolean includeTimestamps = aggClass.getAnnotation(GroupingAggregator.class).includeTimestamps();
groupingAggregatorImplementer = new GroupingAggregatorImplementer(
env.getElementUtils(),
aggClass,
intermediateState,
warnExceptionsTypes,
- includeTimestamps
+ aggClass.getAnnotation(GroupingAggregator.class).includeTimestamps()
);
write(aggClass, "grouping aggregator", groupingAggregatorImplementer.sourceFile(), env);
}
diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java
index 8224c73936b90..d2b6a0e011687 100644
--- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java
+++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java
@@ -17,28 +17,35 @@
import org.elasticsearch.compute.ann.Aggregator;
import org.elasticsearch.compute.ann.IntermediateState;
+import org.elasticsearch.compute.gen.AggregatorImplementer.AggregationParameter;
+import org.elasticsearch.compute.gen.AggregatorImplementer.AggregationState;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;
-import java.util.regex.Pattern;
+import java.util.function.Function;
import java.util.stream.Collectors;
+import java.util.stream.Stream;
import javax.lang.model.element.ExecutableElement;
import javax.lang.model.element.Modifier;
import javax.lang.model.element.TypeElement;
-import javax.lang.model.type.TypeKind;
import javax.lang.model.type.TypeMirror;
import javax.lang.model.util.Elements;
import static java.util.stream.Collectors.joining;
-import static org.elasticsearch.compute.gen.AggregatorImplementer.firstUpper;
-import static org.elasticsearch.compute.gen.AggregatorImplementer.valueBlockType;
-import static org.elasticsearch.compute.gen.AggregatorImplementer.valueVectorType;
-import static org.elasticsearch.compute.gen.Methods.findMethod;
-import static org.elasticsearch.compute.gen.Methods.findRequiredMethod;
+import static org.elasticsearch.compute.gen.AggregatorImplementer.capitalize;
+import static org.elasticsearch.compute.gen.Methods.requireAnyArgs;
+import static org.elasticsearch.compute.gen.Methods.requireAnyType;
+import static org.elasticsearch.compute.gen.Methods.requireArgs;
+import static org.elasticsearch.compute.gen.Methods.requireName;
+import static org.elasticsearch.compute.gen.Methods.requirePrimitiveOrImplements;
+import static org.elasticsearch.compute.gen.Methods.requireStaticMethod;
+import static org.elasticsearch.compute.gen.Methods.requireType;
+import static org.elasticsearch.compute.gen.Methods.requireVoidType;
import static org.elasticsearch.compute.gen.Methods.vectorAccessorName;
import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS;
+import static org.elasticsearch.compute.gen.Types.BLOCK;
import static org.elasticsearch.compute.gen.Types.BLOCK_ARRAY;
import static org.elasticsearch.compute.gen.Types.BYTES_REF;
import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT;
@@ -55,6 +62,8 @@
import static org.elasticsearch.compute.gen.Types.PAGE;
import static org.elasticsearch.compute.gen.Types.SEEN_GROUP_IDS;
import static org.elasticsearch.compute.gen.Types.WARNINGS;
+import static org.elasticsearch.compute.gen.Types.blockType;
+import static org.elasticsearch.compute.gen.Types.vectorType;
/**
* Implements "GroupingAggregationFunction" from a class containing static methods
@@ -70,17 +79,14 @@ public class GroupingAggregatorImplementer {
private final List warnExceptions;
private final ExecutableElement init;
private final ExecutableElement combine;
- private final ExecutableElement combineStates;
- private final ExecutableElement evaluateFinal;
- private final ExecutableElement combineIntermediate;
- private final TypeName stateType;
- private final boolean valuesIsBytesRef;
- private final boolean valuesIsArray;
private final List createParameters;
private final ClassName implementation;
private final List intermediateState;
private final boolean includeTimestampVector;
+ private final AggregationState aggState;
+ private final AggregationParameter aggParam;
+
public GroupingAggregatorImplementer(
Elements elements,
TypeElement declarationType,
@@ -91,21 +97,23 @@ public GroupingAggregatorImplementer(
this.declarationType = declarationType;
this.warnExceptions = warnExceptions;
- this.init = findRequiredMethod(declarationType, new String[] { "init", "initGrouping" }, e -> true);
- this.stateType = choseStateType();
+ this.init = requireStaticMethod(
+ declarationType,
+ requirePrimitiveOrImplements(elements, Types.GROUPING_AGGREGATOR_STATE),
+ requireName("init", "initGrouping"),
+ requireAnyArgs("")
+ );
+ this.aggState = AggregationState.create(elements, init.getReturnType(), warnExceptions.isEmpty() == false, true);
+
+ this.combine = requireStaticMethod(
+ declarationType,
+ aggState.declaredType().isPrimitive() ? requireType(aggState.declaredType()) : requireVoidType(),
+ requireName("combine"),
+ combineArgs(aggState, includeTimestampVector)
+ );
+ // TODO support multiple parameters
+ this.aggParam = AggregationParameter.create(combine.getParameters().getLast().asType());
- this.combine = findRequiredMethod(declarationType, new String[] { "combine" }, e -> {
- if (e.getParameters().size() == 0) {
- return false;
- }
- TypeName firstParamType = TypeName.get(e.getParameters().get(0).asType());
- return firstParamType.isPrimitive() || firstParamType.toString().equals(stateType.toString());
- });
- this.combineStates = findMethod(declarationType, "combineStates");
- this.combineIntermediate = findMethod(declarationType, "combineIntermediate");
- this.evaluateFinal = findMethod(declarationType, "evaluateFinal");
- this.valuesIsBytesRef = BYTES_REF.equals(valueTypeName());
- this.valuesIsArray = TypeKind.ARRAY.equals(valueTypeKind());
this.createParameters = init.getParameters()
.stream()
.map(Parameter::from)
@@ -117,12 +125,31 @@ public GroupingAggregatorImplementer(
(declarationType.getSimpleName() + "GroupingAggregatorFunction").replace("AggregatorGroupingAggregator", "GroupingAggregator")
);
- intermediateState = Arrays.stream(interStateAnno)
+ this.intermediateState = Arrays.stream(interStateAnno)
.map(AggregatorImplementer.IntermediateStateDesc::newIntermediateStateDesc)
.toList();
this.includeTimestampVector = includeTimestampVector;
}
+ private static Methods.ArgumentMatcher combineArgs(AggregationState aggState, boolean includeTimestampVector) {
+ if (aggState.declaredType().isPrimitive()) {
+ return requireArgs(requireType(aggState.declaredType()), requireAnyType(""));
+ } else if (includeTimestampVector) {
+ return requireArgs(
+ requireType(aggState.declaredType()),
+ requireType(TypeName.INT),
+ requireType(TypeName.LONG), // @timestamp
+ requireAnyType("")
+ );
+ } else {
+ return requireArgs(
+ requireType(aggState.declaredType()),
+ requireType(TypeName.INT),
+ requireAnyType("")
+ );
+ }
+ }
+
public ClassName implementation() {
return implementation;
}
@@ -131,18 +158,6 @@ List createParameters() {
return createParameters;
}
- private TypeName choseStateType() {
- TypeName initReturn = TypeName.get(init.getReturnType());
- if (false == initReturn.isPrimitive()) {
- return initReturn;
- }
- String simpleName = firstUpper(initReturn.toString());
- if (warnExceptions.isEmpty()) {
- return ClassName.get("org.elasticsearch.compute.aggregation", simpleName + "ArrayState");
- }
- return ClassName.get("org.elasticsearch.compute.aggregation", simpleName + "FallibleArrayState");
- }
-
public JavaFile sourceFile() {
JavaFile.Builder builder = JavaFile.builder(implementation.packageName(), type());
builder.addFileComment("""
@@ -164,7 +179,7 @@ private TypeSpec type() {
.initializer(initInterState())
.build()
);
- builder.addField(stateType, "state", Modifier.PRIVATE, Modifier.FINAL);
+ builder.addField(aggState.type(), "state", Modifier.PRIVATE, Modifier.FINAL);
if (warnExceptions.isEmpty() == false) {
builder.addField(WARNINGS, "warnings", Modifier.PRIVATE, Modifier.FINAL);
}
@@ -180,10 +195,10 @@ private TypeSpec type() {
builder.addMethod(intermediateStateDesc());
builder.addMethod(intermediateBlockCount());
builder.addMethod(prepareProcessPage());
- builder.addMethod(addRawInputLoop(INT_VECTOR, valueBlockType(init, combine)));
- builder.addMethod(addRawInputLoop(INT_VECTOR, valueVectorType(init, combine)));
- builder.addMethod(addRawInputLoop(INT_BLOCK, valueBlockType(init, combine)));
- builder.addMethod(addRawInputLoop(INT_BLOCK, valueVectorType(init, combine)));
+ builder.addMethod(addRawInputLoop(INT_VECTOR, blockType(aggParam.type())));
+ builder.addMethod(addRawInputLoop(INT_VECTOR, vectorType(aggParam.type())));
+ builder.addMethod(addRawInputLoop(INT_BLOCK, blockType(aggParam.type())));
+ builder.addMethod(addRawInputLoop(INT_BLOCK, vectorType(aggParam.type())));
builder.addMethod(selectedMayContainUnseenGroups());
builder.addMethod(addIntermediateInput());
builder.addMethod(addIntermediateRowInput());
@@ -230,16 +245,16 @@ private CodeBlock callInit() {
.map(p -> TypeName.get(p.asType()).equals(BIG_ARRAYS) ? "driverContext.bigArrays()" : p.getSimpleName().toString())
.collect(joining(", "));
CodeBlock.Builder builder = CodeBlock.builder();
- if (init.getReturnType().toString().equals(stateType.toString())) {
- builder.add("$T.$L($L)", declarationType, init.getSimpleName(), initParametersCall);
- } else {
+ if (aggState.declaredType().isPrimitive()) {
builder.add(
"new $T(driverContext.bigArrays(), $T.$L($L))",
- stateType,
+ aggState.type(),
declarationType,
init.getSimpleName(),
initParametersCall
);
+ } else {
+ builder.add("$T.$L($L)", declarationType, init.getSimpleName(), initParametersCall);
}
return builder.build();
}
@@ -263,7 +278,7 @@ private MethodSpec ctor() {
builder.addParameter(WARNINGS, "warnings");
}
builder.addParameter(LIST_INTEGER, "channels");
- builder.addParameter(stateType, "state");
+ builder.addParameter(aggState.type(), "state");
builder.addParameter(DRIVER_CONTEXT, "driverContext");
if (warnExceptions.isEmpty() == false) {
builder.addStatement("this.warnings = warnings");
@@ -301,8 +316,8 @@ private MethodSpec prepareProcessPage() {
builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC).returns(GROUPING_AGGREGATOR_FUNCTION_ADD_INPUT);
builder.addParameter(SEEN_GROUP_IDS, "seenGroupIds").addParameter(PAGE, "page");
- builder.addStatement("$T valuesBlock = page.getBlock(channels.get(0))", valueBlockType(init, combine));
- builder.addStatement("$T valuesVector = valuesBlock.asVector()", valueVectorType(init, combine));
+ builder.addStatement("$T valuesBlock = page.getBlock(channels.get(0))", blockType(aggParam.type()));
+ builder.addStatement("$T valuesVector = valuesBlock.asVector()", vectorType(aggParam.type()));
if (includeTimestampVector) {
builder.addStatement("$T timestampsBlock = page.getBlock(channels.get(1))", LONG_BLOCK);
builder.addStatement("$T timestampsVector = timestampsBlock.asVector()", LONG_VECTOR);
@@ -355,18 +370,17 @@ private TypeSpec addInput(Consumer addBlock) {
private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) {
boolean groupsIsBlock = groupsType.toString().endsWith("Block");
boolean valuesIsBlock = valuesType.toString().endsWith("Block");
- String methodName = "addRawInput";
- MethodSpec.Builder builder = MethodSpec.methodBuilder(methodName);
+ MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawInput");
builder.addModifiers(Modifier.PRIVATE);
builder.addParameter(TypeName.INT, "positionOffset").addParameter(groupsType, "groups").addParameter(valuesType, "values");
if (includeTimestampVector) {
builder.addParameter(LONG_VECTOR, "timestamps");
}
- if (valuesIsBytesRef) {
+ if (aggParam.isBytesRef()) {
// Add bytes_ref scratch var that will be used for bytes_ref blocks/vectors
builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF);
}
- if (valuesIsArray && valuesIsBlock == false) {
+ if (aggParam.isArray() && valuesIsBlock == false) {
builder.addComment("This type does not support vectors because all values are multi-valued");
return builder.build();
}
@@ -397,11 +411,11 @@ private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) {
builder.endControlFlow();
builder.addStatement("int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset)");
builder.addStatement("int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset)");
- if (valuesIsArray) {
- String arrayType = valueTypeString();
+ if (aggParam.isArray()) {
+ String arrayType = aggParam.type().toString().replace("[]", "");
builder.addStatement("$L[] valuesArray = new $L[valuesEnd - valuesStart]", arrayType, arrayType);
builder.beginControlFlow("for (int v = valuesStart; v < valuesEnd; v++)");
- builder.addStatement("valuesArray[v-valuesStart] = $L.get$L(v)", "values", firstUpper(arrayType));
+ builder.addStatement("valuesArray[v-valuesStart] = $L.get$L(v)", "values", capitalize(arrayType));
builder.endControlFlow();
combineRawInputForArray(builder, "valuesArray");
} else {
@@ -422,14 +436,12 @@ private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) {
}
private void combineRawInput(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
- TypeName valueType = valueTypeName();
+ TypeName valueType = aggParam.type();
TypeName returnType = TypeName.get(combine.getReturnType());
warningsBlock(builder, () -> {
- if (valuesIsBytesRef) {
+ if (aggParam.isBytesRef()) {
combineRawInputForBytesRef(builder, blockVariable, offsetVariable);
- } else if (includeTimestampVector) {
- combineRawInputWithTimestamp(builder, offsetVariable);
} else if (valueType.isPrimitive() == false) {
throw new IllegalArgumentException("second parameter to combine must be a primitive, array or BytesRef: " + valueType);
} else if (returnType.isPrimitive()) {
@@ -442,48 +454,75 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable, S
});
}
- private void combineRawInputForPrimitive(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
- builder.addStatement(
- "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.get$L($L)))",
- declarationType,
- blockVariable,
- firstUpper(valueTypeName().toString()),
- offsetVariable
- );
+ private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
+ // scratch is a BytesRef var that must have been defined before the iteration starts
+ if (includeTimestampVector) {
+ if (offsetVariable.contains(" + ")) {
+ builder.addStatement("var valuePosition = $L", offsetVariable);
+ offsetVariable = "valuePosition";
+ }
+ builder.addStatement(
+ "$T.combine(state, groupId, timestamps.getLong($L), $L.getBytesRef($L, scratch))",
+ declarationType,
+ offsetVariable,
+ blockVariable,
+ offsetVariable
+ );
+ } else {
+ builder.addStatement("$T.combine(state, groupId, $L.getBytesRef($L, scratch))", declarationType, blockVariable, offsetVariable);
+ }
}
- private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) {
- warningsBlock(builder, () -> builder.addStatement("$T.combine(state, groupId, $L)", declarationType, arrayVariable));
+ private void combineRawInputForPrimitive(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
+ if (includeTimestampVector) {
+ if (offsetVariable.contains(" + ")) {
+ builder.addStatement("var valuePosition = $L", offsetVariable);
+ offsetVariable = "valuePosition";
+ }
+ builder.addStatement(
+ "$T.combine(state, groupId, timestamps.getLong($L), values.get$L($L))",
+ declarationType,
+ offsetVariable,
+ capitalize(aggParam.type().toString()),
+ offsetVariable
+ );
+ } else {
+ builder.addStatement(
+ "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.get$L($L)))",
+ declarationType,
+ blockVariable,
+ capitalize(aggParam.type().toString()),
+ offsetVariable
+ );
+ }
}
private void combineRawInputForVoid(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
- builder.addStatement(
- "$T.combine(state, groupId, $L.get$L($L))",
- declarationType,
- blockVariable,
- firstUpper(valueTypeName().toString()),
- offsetVariable
- );
- }
-
- private void combineRawInputWithTimestamp(MethodSpec.Builder builder, String offsetVariable) {
- String blockType = firstUpper(valueTypeName().toString());
- if (offsetVariable.contains(" + ")) {
- builder.addStatement("var valuePosition = $L", offsetVariable);
- offsetVariable = "valuePosition";
+ if (includeTimestampVector) {
+ if (offsetVariable.contains(" + ")) {
+ builder.addStatement("var valuePosition = $L", offsetVariable);
+ offsetVariable = "valuePosition";
+ }
+ builder.addStatement(
+ "$T.combine(state, groupId, timestamps.getLong($L), values.get$L($L))",
+ declarationType,
+ offsetVariable,
+ capitalize(aggParam.type().toString()),
+ offsetVariable
+ );
+ } else {
+ builder.addStatement(
+ "$T.combine(state, groupId, $L.get$L($L))",
+ declarationType,
+ blockVariable,
+ capitalize(aggParam.type().toString()),
+ offsetVariable
+ );
}
- builder.addStatement(
- "$T.combine(state, groupId, timestamps.getLong($L), values.get$L($L))",
- declarationType,
- offsetVariable,
- blockType,
- offsetVariable
- );
}
- private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable, String offsetVariable) {
- // scratch is a BytesRef var that must have been defined before the iteration starts
- builder.addStatement("$T.combine(state, groupId, $L.getBytesRef($L, scratch))", declarationType, blockVariable, offsetVariable);
+ private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) {
+ warningsBlock(builder, () -> builder.addStatement("$T.combine(state, groupId, $L)", declarationType, arrayVariable));
}
private void warningsBlock(MethodSpec.Builder builder, Runnable block) {
@@ -539,7 +578,7 @@ private MethodSpec addIntermediateInput() {
builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)");
{
builder.addStatement("int groupId = groups.getInt(groupPosition)");
- if (hasPrimitiveState()) {
+ if (aggState.declaredType().isPrimitive()) {
if (warnExceptions.isEmpty()) {
assert intermediateState.size() == 2;
assert intermediateState.get(1).name().equals("seen");
@@ -567,31 +606,33 @@ private MethodSpec addIntermediateInput() {
});
builder.endControlFlow();
} else {
- builder.addStatement("$T.combineIntermediate(state, groupId, " + intermediateStateRowAccess() + ")", declarationType);
+ var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block);
+ requireStaticMethod(
+ declarationType,
+ requireVoidType(),
+ requireName("combineIntermediate"),
+ requireArgs(
+ Stream.of(
+ Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId
+ intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType),
+ Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position
+ ).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new)
+ )
+ );
+
+ builder.addStatement(
+ "$T.combineIntermediate(state, groupId, "
+ + intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", "))
+ + (stateHasBlock ? ", groupPosition + positionOffset" : "")
+ + ")",
+ declarationType
+ );
}
builder.endControlFlow();
}
return builder.build();
}
- String intermediateStateRowAccess() {
- String rowAccess = intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", "));
- if (intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block)) {
- rowAccess += ", groupPosition + positionOffset";
- }
- return rowAccess;
- }
-
- private void combineStates(MethodSpec.Builder builder) {
- if (combineStates == null) {
- builder.beginControlFlow("if (inState.hasValue(position))");
- builder.addStatement("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))", declarationType);
- builder.endControlFlow();
- return;
- }
- builder.addStatement("$T.combineStates(state, groupId, inState, position)", declarationType);
- }
-
private MethodSpec addIntermediateRowInput() {
MethodSpec.Builder builder = MethodSpec.methodBuilder("addIntermediateRowInput");
builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC);
@@ -601,9 +642,26 @@ private MethodSpec addIntermediateRowInput() {
builder.addStatement("throw new IllegalArgumentException($S + getClass() + $S + input.getClass())", "expected ", "; got ");
}
builder.endControlFlow();
- builder.addStatement("$T inState = (($T) input).state", stateType, implementation);
+ builder.addStatement("$T inState = (($T) input).state", aggState.type(), implementation);
builder.addStatement("state.enableGroupIdTracking(new $T.Empty())", SEEN_GROUP_IDS);
- combineStates(builder);
+ if (aggState.declaredType().isPrimitive()) {
+ builder.beginControlFlow("if (inState.hasValue(position))");
+ builder.addStatement("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))", declarationType);
+ builder.endControlFlow();
+ } else {
+ requireStaticMethod(
+ declarationType,
+ requireVoidType(),
+ requireName("combineStates"),
+ requireArgs(
+ requireType(aggState.declaredType()),
+ requireType(TypeName.INT),
+ requireType(aggState.declaredType()),
+ requireType(TypeName.INT)
+ )
+ );
+ builder.addStatement("$T.combineStates(state, groupId, inState, position)", declarationType);
+ }
return builder.build();
}
@@ -627,9 +685,15 @@ private MethodSpec evaluateFinal() {
.addParameter(INT_VECTOR, "selected")
.addParameter(DRIVER_CONTEXT, "driverContext");
- if (evaluateFinal == null) {
+ if (aggState.declaredType().isPrimitive()) {
builder.addStatement("blocks[offset] = state.toValuesBlock(selected, driverContext)");
} else {
+ requireStaticMethod(
+ declarationType,
+ requireType(BLOCK),
+ requireName("evaluateFinal"),
+ requireArgs(requireType(aggState.declaredType()), requireType(INT_VECTOR), requireType(DRIVER_CONTEXT))
+ );
builder.addStatement("blocks[offset] = $T.evaluateFinal(state, selected, driverContext)", declarationType);
}
return builder.build();
@@ -652,32 +716,4 @@ private MethodSpec close() {
builder.addStatement("state.close()");
return builder.build();
}
-
- private static final Pattern PRIMITIVE_STATE_PATTERN = Pattern.compile(
- "org.elasticsearch.compute.aggregation.(Boolean|Int|Long|Double|Float)(Fallible)?ArrayState"
- );
-
- private boolean hasPrimitiveState() {
- return PRIMITIVE_STATE_PATTERN.matcher(stateType.toString()).matches();
- }
-
- private TypeMirror valueTypeMirror() {
- return combine.getParameters().get(combine.getParameters().size() - 1).asType();
- }
-
- private TypeName valueTypeName() {
- return TypeName.get(valueTypeMirror());
- }
-
- private TypeKind valueTypeKind() {
- return valueTypeMirror().getKind();
- }
-
- private String valueTypeString() {
- String valueTypeString = TypeName.get(valueTypeMirror()).toString();
- if (valuesIsArray) {
- valueTypeString = valueTypeString.substring(0, valueTypeString.length() - 2);
- }
- return valueTypeString;
- }
}
diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java
index 6f98f1f797ab0..f2fa7b8084448 100644
--- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java
+++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java
@@ -9,18 +9,22 @@
import com.squareup.javapoet.TypeName;
-import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
import java.util.function.Predicate;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
-import javax.lang.model.element.Element;
import javax.lang.model.element.ExecutableElement;
import javax.lang.model.element.Modifier;
import javax.lang.model.element.TypeElement;
-import javax.lang.model.element.VariableElement;
import javax.lang.model.type.DeclaredType;
-import javax.lang.model.type.TypeMirror;
+import javax.lang.model.type.TypeKind;
import javax.lang.model.util.ElementFilter;
+import javax.lang.model.util.Elements;
+import static java.util.stream.Collectors.joining;
import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK;
import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK_BUILDER;
import static org.elasticsearch.compute.gen.Types.BOOLEAN_VECTOR;
@@ -49,30 +53,116 @@
* Finds declared methods for the code generator.
*/
public class Methods {
- static ExecutableElement findRequiredMethod(TypeElement declarationType, String[] names, Predicate filter) {
- ExecutableElement result = findMethod(names, filter, declarationType, superClassOf(declarationType));
- if (result == null) {
- if (names.length == 1) {
- throw new IllegalArgumentException(declarationType + "#" + names[0] + " is required");
- }
- throw new IllegalArgumentException("one of " + declarationType + "#" + Arrays.toString(names) + " is required");
+
+ static ExecutableElement requireStaticMethod(
+ TypeElement declarationType,
+ TypeMatcher returnTypeMatcher,
+ NameMatcher nameMatcher,
+ ArgumentMatcher argumentMatcher
+ ) {
+ return typeAndSuperType(declarationType).flatMap(type -> ElementFilter.methodsIn(type.getEnclosedElements()).stream())
+ .filter(method -> method.getModifiers().contains(Modifier.STATIC))
+ .filter(method -> nameMatcher.test(method.getSimpleName().toString()))
+ .filter(method -> returnTypeMatcher.test(TypeName.get(method.getReturnType())))
+ .filter(method -> argumentMatcher.test(method.getParameters().stream().map(it -> TypeName.get(it.asType())).toList()))
+ .findFirst()
+ .orElseThrow(() -> {
+ var message = nameMatcher.names.size() == 1 ? "Requires method: " : "Requires one of methods: ";
+ var signatures = nameMatcher.names.stream()
+ .map(name -> "public static " + returnTypeMatcher + " " + declarationType + "#" + name + "(" + argumentMatcher + ")")
+ .collect(joining(" or "));
+ return new IllegalArgumentException(message + signatures);
+ });
+ }
+
+ static NameMatcher requireName(String... names) {
+ return new NameMatcher(Set.of(names));
+ }
+
+ static TypeMatcher requireVoidType() {
+ return new TypeMatcher(type -> Objects.equals(TypeName.VOID, type), "void");
+ }
+
+ static TypeMatcher requireAnyType(String description) {
+ return new TypeMatcher(type -> true, description);
+ }
+
+ static TypeMatcher requirePrimitiveOrImplements(Elements elements, TypeName requiredInterface) {
+ return new TypeMatcher(
+ type -> type.isPrimitive() || isImplementing(elements, type, requiredInterface),
+ "[boolean|int|long|float|double|" + requiredInterface + "]"
+ );
+ }
+
+ static TypeMatcher requireType(TypeName requiredType) {
+ return new TypeMatcher(type -> Objects.equals(requiredType, type), requiredType.toString());
+ }
+
+ static ArgumentMatcher requireAnyArgs(String description) {
+ return new ArgumentMatcher(args -> true, description);
+ }
+
+ static ArgumentMatcher requireArgs(TypeMatcher... argTypes) {
+ return new ArgumentMatcher(
+ args -> args.size() == argTypes.length && IntStream.range(0, argTypes.length).allMatch(i -> argTypes[i].test(args.get(i))),
+ Stream.of(argTypes).map(TypeMatcher::toString).collect(joining(", "))
+ );
+ }
+
+ record NameMatcher(Set names) implements Predicate {
+ @Override
+ public boolean test(String name) {
+ return names.contains(name);
}
- return result;
}
- static ExecutableElement findMethod(TypeElement declarationType, String name) {
- return findMethod(new String[] { name }, e -> true, declarationType, superClassOf(declarationType));
+ record TypeMatcher(Predicate matcher, String description) implements Predicate {
+ @Override
+ public boolean test(TypeName typeName) {
+ return matcher.test(typeName);
+ }
+
+ @Override
+ public String toString() {
+ return description;
+ }
}
- private static TypeElement superClassOf(TypeElement declarationType) {
- TypeMirror superclass = declarationType.getSuperclass();
- if (superclass instanceof DeclaredType declaredType) {
- Element superclassElement = declaredType.asElement();
- if (superclassElement instanceof TypeElement) {
- return (TypeElement) superclassElement;
- }
+ record ArgumentMatcher(Predicate> matcher, String description) implements Predicate> {
+ @Override
+ public boolean test(List typeName) {
+ return matcher.test(typeName);
+ }
+
+ @Override
+ public String toString() {
+ return description;
+ }
+ }
+
+ private static boolean isImplementing(Elements elements, TypeName type, TypeName requiredInterface) {
+ return allInterfacesOf(elements, type).anyMatch(
+ anInterface -> Objects.equals(anInterface.toString(), requiredInterface.toString())
+ );
+ }
+
+ private static Stream allInterfacesOf(Elements elements, TypeName type) {
+ var typeElement = elements.getTypeElement(type.toString());
+ var superType = Stream.of(typeElement.getSuperclass()).filter(sType -> sType.getKind() != TypeKind.NONE).map(TypeName::get);
+ var interfaces = typeElement.getInterfaces().stream().map(TypeName::get);
+ return Stream.concat(
+ superType.flatMap(sType -> allInterfacesOf(elements, sType)),
+ interfaces.flatMap(anInterface -> Stream.concat(Stream.of(anInterface), allInterfacesOf(elements, anInterface)))
+ );
+ }
+
+ private static Stream typeAndSuperType(TypeElement declarationType) {
+ if (declarationType.getSuperclass() instanceof DeclaredType declaredType
+ && declaredType.asElement() instanceof TypeElement superType) {
+ return Stream.of(declarationType, superType);
+ } else {
+ return Stream.of(declarationType);
}
- return null;
}
static ExecutableElement findMethod(TypeElement declarationType, String[] names, Predicate filter) {
@@ -95,16 +185,6 @@ static ExecutableElement findMethod(String[] names, Predicate
return null;
}
- /**
- * Returns the arguments of a method after applying a filter.
- */
- static VariableElement[] findMethodArguments(ExecutableElement method, Predicate filter) {
- if (method.getParameters().isEmpty()) {
- return new VariableElement[0];
- }
- return method.getParameters().stream().filter(filter).toArray(VariableElement[]::new);
- }
-
/**
* Returns the name of the method used to add {@code valueType} instances
* to vector or block builders.
diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java
index 8b01d957f3bd2..35c42153f9ad6 100644
--- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java
+++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java
@@ -15,9 +15,13 @@
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
+import java.util.Map;
+import java.util.stream.Stream;
import javax.lang.model.type.TypeMirror;
+import static java.util.stream.Collectors.toUnmodifiableMap;
+
/**
* Types used by the code generator.
*/
@@ -75,26 +79,8 @@ public class Types {
static final ClassName DOUBLE_VECTOR_FIXED_BUILDER = ClassName.get(DATA_PACKAGE, "DoubleVector", "FixedBuilder");
static final ClassName FLOAT_VECTOR_FIXED_BUILDER = ClassName.get(DATA_PACKAGE, "FloatVector", "FixedBuilder");
- static final ClassName BOOLEAN_ARRAY_VECTOR = ClassName.get(DATA_PACKAGE, "BooleanArrayVector");
- static final ClassName BYTES_REF_ARRAY_VECTOR = ClassName.get(DATA_PACKAGE, "BytesRefArrayVector");
- static final ClassName INT_ARRAY_VECTOR = ClassName.get(DATA_PACKAGE, "IntArrayVector");
- static final ClassName LONG_ARRAY_VECTOR = ClassName.get(DATA_PACKAGE, "LongArrayVector");
- static final ClassName DOUBLE_ARRAY_VECTOR = ClassName.get(DATA_PACKAGE, "DoubleArrayVector");
- static final ClassName FLOAT_ARRAY_VECTOR = ClassName.get(DATA_PACKAGE, "FloatArrayVector");
-
- static final ClassName BOOLEAN_ARRAY_BLOCK = ClassName.get(DATA_PACKAGE, "BooleanArrayBlock");
- static final ClassName BYTES_REF_ARRAY_BLOCK = ClassName.get(DATA_PACKAGE, "BytesRefArrayBlock");
- static final ClassName INT_ARRAY_BLOCK = ClassName.get(DATA_PACKAGE, "IntArrayBlock");
- static final ClassName LONG_ARRAY_BLOCK = ClassName.get(DATA_PACKAGE, "LongArrayBlock");
- static final ClassName DOUBLE_ARRAY_BLOCK = ClassName.get(DATA_PACKAGE, "DoubleArrayBlock");
- static final ClassName FLOAT_ARRAY_BLOCK = ClassName.get(DATA_PACKAGE, "FloatArrayBlock");
-
- static final ClassName BOOLEAN_CONSTANT_VECTOR = ClassName.get(DATA_PACKAGE, "ConstantBooleanVector");
- static final ClassName BYTES_REF_CONSTANT_VECTOR = ClassName.get(DATA_PACKAGE, "ConstantBytesRefVector");
- static final ClassName INT_CONSTANT_VECTOR = ClassName.get(DATA_PACKAGE, "ConstantIntVector");
- static final ClassName LONG_CONSTANT_VECTOR = ClassName.get(DATA_PACKAGE, "ConstantLongVector");
- static final ClassName DOUBLE_CONSTANT_VECTOR = ClassName.get(DATA_PACKAGE, "ConstantDoubleVector");
- static final ClassName FLOAT_CONSTANT_VECTOR = ClassName.get(DATA_PACKAGE, "ConstantFloatVector");
+ static final ClassName AGGREGATOR_STATE = ClassName.get(AGGREGATION_PACKAGE, "AggregatorState");
+ static final ClassName GROUPING_AGGREGATOR_STATE = ClassName.get(AGGREGATION_PACKAGE, "GroupingAggregatorState");
static final ClassName AGGREGATOR_FUNCTION = ClassName.get(AGGREGATION_PACKAGE, "AggregatorFunction");
static final ClassName AGGREGATOR_FUNCTION_SUPPLIER = ClassName.get(AGGREGATION_PACKAGE, "AggregatorFunctionSupplier");
@@ -138,89 +124,50 @@ public class Types {
static final ClassName RELEASABLE = ClassName.get("org.elasticsearch.core", "Releasable");
static final ClassName RELEASABLES = ClassName.get("org.elasticsearch.core", "Releasables");
- static ClassName blockType(TypeName elementType) {
- if (elementType.equals(TypeName.BOOLEAN)) {
- return BOOLEAN_BLOCK;
- }
- if (elementType.equals(BYTES_REF)) {
- return BYTES_REF_BLOCK;
- }
- if (elementType.equals(TypeName.INT)) {
- return INT_BLOCK;
- }
- if (elementType.equals(TypeName.LONG)) {
- return LONG_BLOCK;
- }
- if (elementType.equals(TypeName.DOUBLE)) {
- return DOUBLE_BLOCK;
+ private record TypeDef(TypeName type, String alias, ClassName block, ClassName vector) {
+
+ public static TypeDef of(TypeName type, String alias, String block, String vector) {
+ return new TypeDef(type, alias, ClassName.get(DATA_PACKAGE, block), ClassName.get(DATA_PACKAGE, vector));
}
- throw new IllegalArgumentException("unknown block type for [" + elementType + "]");
+ }
+
+ private static final Map TYPES = Stream.of(
+ TypeDef.of(TypeName.BOOLEAN, "BOOLEAN", "BooleanBlock", "BooleanVector"),
+ TypeDef.of(TypeName.INT, "INT", "IntBlock", "IntVector"),
+ TypeDef.of(TypeName.LONG, "LONG", "LongBlock", "LongVector"),
+ TypeDef.of(TypeName.FLOAT, "FLOAT", "FloatBlock", "FloatVector"),
+ TypeDef.of(TypeName.DOUBLE, "DOUBLE", "DoubleBlock", "DoubleVector"),
+ TypeDef.of(BYTES_REF, "BYTES_REF", "BytesRefBlock", "BytesRefVector")
+ )
+ .flatMap(def -> Stream.of(def.type.toString(), def.type + "[]", def.alias).map(alias -> Map.entry(alias, def)))
+ .collect(toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue));
+
+ private static TypeDef findRequired(String name, String kind) {
+ TypeDef typeDef = TYPES.get(name);
+ if (typeDef == null) {
+ throw new IllegalArgumentException("unknown " + kind + " type [" + name + "]");
+ }
+ return typeDef;
+ }
+
+ static TypeName fromString(String type) {
+ return findRequired(type, "plain").type;
+ }
+
+ static ClassName blockType(TypeName elementType) {
+ return blockType(elementType.toString());
}
static ClassName blockType(String elementType) {
- if (elementType.equalsIgnoreCase(TypeName.BOOLEAN.toString())) {
- return BOOLEAN_BLOCK;
- }
- if (elementType.equalsIgnoreCase("BYTES_REF")) {
- return BYTES_REF_BLOCK;
- }
- if (elementType.equalsIgnoreCase(TypeName.INT.toString())) {
- return INT_BLOCK;
- }
- if (elementType.equalsIgnoreCase(TypeName.LONG.toString())) {
- return LONG_BLOCK;
- }
- if (elementType.equalsIgnoreCase(TypeName.DOUBLE.toString())) {
- return DOUBLE_BLOCK;
- }
- if (elementType.equalsIgnoreCase(TypeName.FLOAT.toString())) {
- return FLOAT_BLOCK;
- }
- throw new IllegalArgumentException("unknown vector type for [" + elementType + "]");
+ return findRequired(elementType, "block").block;
}
static ClassName vectorType(TypeName elementType) {
- if (elementType.equals(TypeName.BOOLEAN)) {
- return BOOLEAN_VECTOR;
- }
- if (elementType.equals(BYTES_REF)) {
- return BYTES_REF_VECTOR;
- }
- if (elementType.equals(TypeName.INT)) {
- return INT_VECTOR;
- }
- if (elementType.equals(TypeName.LONG)) {
- return LONG_VECTOR;
- }
- if (elementType.equals(TypeName.DOUBLE)) {
- return DOUBLE_VECTOR;
- }
- if (elementType.equals(TypeName.FLOAT)) {
- return FLOAT_VECTOR;
- }
- throw new IllegalArgumentException("unknown vector type for [" + elementType + "]");
+ return vectorType(elementType.toString());
}
static ClassName vectorType(String elementType) {
- if (elementType.equalsIgnoreCase(TypeName.BOOLEAN.toString())) {
- return BOOLEAN_VECTOR;
- }
- if (elementType.equalsIgnoreCase("BYTES_REF")) {
- return BYTES_REF_VECTOR;
- }
- if (elementType.equalsIgnoreCase(TypeName.INT.toString())) {
- return INT_VECTOR;
- }
- if (elementType.equalsIgnoreCase(TypeName.LONG.toString())) {
- return LONG_VECTOR;
- }
- if (elementType.equalsIgnoreCase(TypeName.DOUBLE.toString())) {
- return DOUBLE_VECTOR;
- }
- if (elementType.equalsIgnoreCase(TypeName.FLOAT.toString())) {
- return FLOAT_VECTOR;
- }
- throw new IllegalArgumentException("unknown vector type for [" + elementType + "]");
+ return findRequired(elementType, "vector").vector;
}
static ClassName builderType(TypeName resultType) {
@@ -282,63 +229,6 @@ static ClassName vectorFixedBuilderType(TypeName elementType) {
throw new IllegalArgumentException("unknown vector fixed builder type for [" + elementType + "]");
}
- static ClassName arrayVectorType(TypeName elementType) {
- if (elementType.equals(TypeName.BOOLEAN)) {
- return BOOLEAN_ARRAY_VECTOR;
- }
- if (elementType.equals(BYTES_REF)) {
- return BYTES_REF_ARRAY_VECTOR;
- }
- if (elementType.equals(TypeName.INT)) {
- return INT_ARRAY_VECTOR;
- }
- if (elementType.equals(TypeName.LONG)) {
- return LONG_ARRAY_VECTOR;
- }
- if (elementType.equals(TypeName.DOUBLE)) {
- return DOUBLE_ARRAY_VECTOR;
- }
- throw new IllegalArgumentException("unknown vector type for [" + elementType + "]");
- }
-
- static ClassName arrayBlockType(TypeName elementType) {
- if (elementType.equals(TypeName.BOOLEAN)) {
- return BOOLEAN_ARRAY_BLOCK;
- }
- if (elementType.equals(BYTES_REF)) {
- return BYTES_REF_ARRAY_BLOCK;
- }
- if (elementType.equals(TypeName.INT)) {
- return INT_ARRAY_BLOCK;
- }
- if (elementType.equals(TypeName.LONG)) {
- return LONG_ARRAY_BLOCK;
- }
- if (elementType.equals(TypeName.DOUBLE)) {
- return DOUBLE_ARRAY_BLOCK;
- }
- throw new IllegalArgumentException("unknown vector type for [" + elementType + "]");
- }
-
- static ClassName constantVectorType(TypeName elementType) {
- if (elementType.equals(TypeName.BOOLEAN)) {
- return BOOLEAN_CONSTANT_VECTOR;
- }
- if (elementType.equals(BYTES_REF)) {
- return BYTES_REF_CONSTANT_VECTOR;
- }
- if (elementType.equals(TypeName.INT)) {
- return INT_CONSTANT_VECTOR;
- }
- if (elementType.equals(TypeName.LONG)) {
- return LONG_CONSTANT_VECTOR;
- }
- if (elementType.equals(TypeName.DOUBLE)) {
- return DOUBLE_CONSTANT_VECTOR;
- }
- throw new IllegalArgumentException("unknown vector type for [" + elementType + "]");
- }
-
static TypeName elementType(TypeName t) {
if (t.equals(BOOLEAN_BLOCK) || t.equals(BOOLEAN_VECTOR) || t.equals(BOOLEAN_BLOCK_BUILDER)) {
return TypeName.BOOLEAN;
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleAggregator.java
index cbd20f15c6511..deec1ef04f623 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleAggregator.java
@@ -333,7 +333,8 @@ Block evaluateFinal(IntVector selected, BlockFactory blockFactory) {
}
}
- void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
// noop - we handle the null states inside `toIntermediate` and `evaluateFinal`
}
}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateFloatAggregator.java
index b50b125d98331..94ad5254bc723 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateFloatAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateFloatAggregator.java
@@ -334,7 +334,8 @@ Block evaluateFinal(IntVector selected, BlockFactory blockFactory) {
}
}
- void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
// noop - we handle the null states inside `toIntermediate` and `evaluateFinal`
}
}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntAggregator.java
index 01c3e3d7fb8e7..011291dd08c52 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntAggregator.java
@@ -334,7 +334,8 @@ Block evaluateFinal(IntVector selected, BlockFactory blockFactory) {
}
}
- void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
// noop - we handle the null states inside `toIntermediate` and `evaluateFinal`
}
}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongAggregator.java
index c84985b703aed..9ccb5d3bd1b1a 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongAggregator.java
@@ -333,7 +333,8 @@ Block evaluateFinal(IntVector selected, BlockFactory blockFactory) {
}
}
- void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
// noop - we handle the null states inside `toIntermediate` and `evaluateFinal`
}
}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java
index 32391c4827303..a2e86b3b09340 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java
@@ -17,7 +17,6 @@
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.sort.BooleanBucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;
@@ -74,7 +73,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(driverContext.blockFactory(), selected);
}
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final BooleanBucketedSort sort;
private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
@@ -89,7 +88,8 @@ public void merge(int groupId, GroupingState other, int otherGroupId) {
sort.merge(groupId, other.sort, otherGroupId);
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
@@ -97,7 +97,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) {
return sort.toBlock(blockFactory, selected);
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
@@ -107,7 +108,7 @@ public void close() {
}
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final GroupingState internalState;
private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
@@ -122,7 +123,8 @@ public void merge(GroupingState other) {
internalState.merge(0, other, 0);
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java
index c9b0e679b3e64..0a965899c0775 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java
@@ -19,7 +19,6 @@
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.sort.BytesRefBucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;
@@ -78,7 +77,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(driverContext.blockFactory(), selected);
}
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final BytesRefBucketedSort sort;
private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
@@ -95,7 +94,8 @@ public void merge(int groupId, GroupingState other, int otherGroupId) {
sort.merge(groupId, other.sort, otherGroupId);
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
@@ -103,7 +103,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) {
return sort.toBlock(blockFactory, selected);
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
@@ -113,7 +114,7 @@ public void close() {
}
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final GroupingState internalState;
private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
@@ -128,7 +129,8 @@ public void merge(GroupingState other) {
internalState.merge(0, other, 0);
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleAggregator.java
index d9a7a302f07c8..6a20ed99bc236 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleAggregator.java
@@ -17,7 +17,6 @@
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.sort.DoubleBucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;
@@ -74,7 +73,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(driverContext.blockFactory(), selected);
}
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final DoubleBucketedSort sort;
private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
@@ -89,7 +88,8 @@ public void merge(int groupId, GroupingState other, int otherGroupId) {
sort.merge(groupId, other.sort, otherGroupId);
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
@@ -97,7 +97,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) {
return sort.toBlock(blockFactory, selected);
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
@@ -107,7 +108,7 @@ public void close() {
}
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final GroupingState internalState;
private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
@@ -122,7 +123,8 @@ public void merge(GroupingState other) {
internalState.merge(0, other, 0);
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatAggregator.java
index 8b65261e10f46..cf6ad0f9017de 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatAggregator.java
@@ -17,7 +17,6 @@
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.sort.FloatBucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;
@@ -74,7 +73,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(driverContext.blockFactory(), selected);
}
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final FloatBucketedSort sort;
private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
@@ -89,7 +88,8 @@ public void merge(int groupId, GroupingState other, int otherGroupId) {
sort.merge(groupId, other.sort, otherGroupId);
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
@@ -97,7 +97,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) {
return sort.toBlock(blockFactory, selected);
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
@@ -107,7 +108,7 @@ public void close() {
}
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final GroupingState internalState;
private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
@@ -122,7 +123,8 @@ public void merge(GroupingState other) {
internalState.merge(0, other, 0);
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntAggregator.java
index 5c6b79f710af5..f4ac83c438063 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntAggregator.java
@@ -17,7 +17,6 @@
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.sort.IntBucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;
@@ -74,7 +73,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(driverContext.blockFactory(), selected);
}
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final IntBucketedSort sort;
private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
@@ -89,7 +88,8 @@ public void merge(int groupId, GroupingState other, int otherGroupId) {
sort.merge(groupId, other.sort, otherGroupId);
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
@@ -97,7 +97,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) {
return sort.toBlock(blockFactory, selected);
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
@@ -107,7 +108,7 @@ public void close() {
}
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final GroupingState internalState;
private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
@@ -122,7 +123,8 @@ public void merge(GroupingState other) {
internalState.merge(0, other, 0);
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIpAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIpAggregator.java
index 219f7385b56df..292dd539edeb5 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIpAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIpAggregator.java
@@ -18,7 +18,6 @@
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.sort.IpBucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;
@@ -77,7 +76,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(driverContext.blockFactory(), selected);
}
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final IpBucketedSort sort;
private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
@@ -92,7 +91,8 @@ public void merge(int groupId, GroupingState other, int otherGroupId) {
sort.merge(groupId, other.sort, otherGroupId);
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
@@ -100,7 +100,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) {
return sort.toBlock(blockFactory, selected);
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
@@ -110,7 +111,7 @@ public void close() {
}
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final GroupingState internalState;
private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
@@ -125,7 +126,8 @@ public void merge(GroupingState other) {
internalState.merge(0, other, 0);
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongAggregator.java
index 44cef8df7257b..c5af92956bec1 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongAggregator.java
@@ -17,7 +17,6 @@
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.data.sort.LongBucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;
@@ -74,7 +73,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(driverContext.blockFactory(), selected);
}
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final LongBucketedSort sort;
private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
@@ -89,7 +88,8 @@ public void merge(int groupId, GroupingState other, int otherGroupId) {
sort.merge(groupId, other.sort, otherGroupId);
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
@@ -97,7 +97,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) {
return sort.toBlock(blockFactory, selected);
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
@@ -107,7 +108,7 @@ public void close() {
}
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final GroupingState internalState;
private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
@@ -122,7 +123,8 @@ public void merge(GroupingState other) {
internalState.merge(0, other, 0);
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java
index bd77bd7ff1e46..ad0ab2f7189f6 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java
@@ -20,7 +20,6 @@
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
/**
@@ -83,14 +82,15 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(driverContext.blockFactory(), selected);
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final BytesRefHash values;
private SingleState(BigArrays bigArrays) {
values = new BytesRefHash(1, bigArrays);
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
@@ -125,7 +125,7 @@ public void close() {
* an {@code O(n^2)} operation for collection to support a {@code O(1)}
* collector operation. But at least it's fairly simple.
*/
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final LongLongHash values;
private final BytesRefHash bytes;
@@ -146,7 +146,8 @@ private GroupingState(BigArrays bigArrays) {
}
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
@@ -190,7 +191,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) {
}
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java
index a8409367bc090..271d7120092ca 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java
@@ -18,7 +18,6 @@
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
/**
* Aggregates field values for double.
@@ -77,14 +76,15 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(driverContext.blockFactory(), selected);
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final LongHash values;
private SingleState(BigArrays bigArrays) {
values = new LongHash(1, bigArrays);
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
@@ -118,14 +118,15 @@ public void close() {
* an {@code O(n^2)} operation for collection to support a {@code O(1)}
* collector operation. But at least it's fairly simple.
*/
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final LongLongHash values;
private GroupingState(BigArrays bigArrays) {
values = new LongLongHash(1, bigArrays);
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
@@ -168,7 +169,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) {
}
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java
index f9e5e1b7b283a..b44cad807fba2 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java
@@ -17,7 +17,6 @@
import org.elasticsearch.compute.data.FloatBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
/**
* Aggregates field values for float.
@@ -82,14 +81,15 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(driverContext.blockFactory(), selected);
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final LongHash values;
private SingleState(BigArrays bigArrays) {
values = new LongHash(1, bigArrays);
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
@@ -123,14 +123,15 @@ public void close() {
* an {@code O(n^2)} operation for collection to support a {@code O(1)}
* collector operation. But at least it's fairly simple.
*/
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final LongHash values;
private GroupingState(BigArrays bigArrays) {
values = new LongHash(1, bigArrays);
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
@@ -175,7 +176,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) {
}
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java
index 2420dcee70712..4d0c518245694 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java
@@ -17,7 +17,6 @@
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
/**
* Aggregates field values for int.
@@ -82,14 +81,15 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(driverContext.blockFactory(), selected);
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final LongHash values;
private SingleState(BigArrays bigArrays) {
values = new LongHash(1, bigArrays);
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
@@ -123,14 +123,15 @@ public void close() {
* an {@code O(n^2)} operation for collection to support a {@code O(1)}
* collector operation. But at least it's fairly simple.
*/
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final LongHash values;
private GroupingState(BigArrays bigArrays) {
values = new LongHash(1, bigArrays);
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
@@ -175,7 +176,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) {
}
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java
index 4938b8f15edb0..5471c90147ec4 100644
--- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java
@@ -18,7 +18,6 @@
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.LongBlock;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
/**
* Aggregates field values for long.
@@ -77,14 +76,15 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(driverContext.blockFactory(), selected);
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final LongHash values;
private SingleState(BigArrays bigArrays) {
values = new LongHash(1, bigArrays);
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
@@ -118,14 +118,15 @@ public void close() {
* an {@code O(n^2)} operation for collection to support a {@code O(1)}
* collector operation. But at least it's fairly simple.
*/
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final LongLongHash values;
private GroupingState(BigArrays bigArrays) {
values = new LongLongHash(1, bigArrays);
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
@@ -168,7 +169,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) {
}
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractArrayState.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractArrayState.java
index 5fa1394e8cf96..9886e0c1af306 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractArrayState.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractArrayState.java
@@ -37,6 +37,7 @@ public boolean hasValue(int groupId) {
* idempotent and fast if already tracking so it's safe to, say, call it once
* for every block of values that arrives containing {@code null}.
*/
+ @Override
public final void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
if (seen == null) {
seen = seenGroupIds.seenGroupIds(bigArrays);
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/BytesRefArrayState.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/BytesRefArrayState.java
index eb0a992c8610f..18b92c5447076 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/BytesRefArrayState.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/BytesRefArrayState.java
@@ -138,7 +138,8 @@ boolean hasValue(int groupId) {
* stores a flag to know if optimizations can be made.
*
*/
- void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
this.groupIdTrackingEnabled = true;
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorState.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorState.java
index 7c644342598dc..0e65164665808 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorState.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorState.java
@@ -17,4 +17,5 @@ public interface GroupingAggregatorState extends Releasable {
/** Extracts an intermediate view of the contents of this state. */
void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext);
+ void enableGroupIdTracking(SeenGroupIds seenGroupIds);
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/HllStates.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/HllStates.java
index 3d8d04d7dc7e3..64a970c2acc07 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/HllStates.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/HllStates.java
@@ -138,7 +138,8 @@ static class GroupingState implements GroupingAggregatorState {
this.hll = new HyperLogLogPlusPlus(HyperLogLogPlusPlus.precisionFromThreshold(precision), bigArrays, 1);
}
- void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
// Nothing to do
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxBytesRefAggregator.java
index 144214f93571e..049642c350917 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxBytesRefAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxBytesRefAggregator.java
@@ -17,7 +17,6 @@
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.operator.BreakingBytesRefBuilder;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
/**
@@ -71,7 +70,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(selected, driverContext);
}
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final BytesRefArrayState internalState;
private GroupingState(BigArrays bigArrays, CircuitBreaker breaker) {
@@ -90,7 +89,8 @@ public void combine(int groupId, GroupingState otherState, int otherGroupId) {
}
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
internalState.toIntermediate(blocks, offset, selected, driverContext);
}
@@ -98,7 +98,8 @@ Block toBlock(IntVector selected, DriverContext driverContext) {
return internalState.toValuesBlock(selected, driverContext);
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
internalState.enableGroupIdTracking(seen);
}
@@ -108,7 +109,7 @@ public void close() {
}
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final BreakingBytesRefBuilder internalState;
private boolean seen;
@@ -128,7 +129,8 @@ public void add(BytesRef value) {
}
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = driverContext.blockFactory().newConstantBytesRefBlockWith(internalState.bytesRefView(), 1);
blocks[offset + 1] = driverContext.blockFactory().newConstantBooleanBlockWith(seen, 1);
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxIpAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxIpAggregator.java
index 1ddce7674ae7b..43b4a4a2fe0a1 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxIpAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxIpAggregator.java
@@ -15,7 +15,6 @@
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
@Aggregator({ @IntermediateState(name = "max", type = "BYTES_REF"), @IntermediateState(name = "seen", type = "BOOLEAN") })
@@ -67,7 +66,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(selected, driverContext);
}
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final BytesRef scratch = new BytesRef();
private final IpArrayState internalState;
@@ -87,7 +86,8 @@ public void combine(int groupId, GroupingState otherState, int otherGroupId) {
}
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
internalState.toIntermediate(blocks, offset, selected, driverContext);
}
@@ -95,7 +95,8 @@ Block toBlock(IntVector selected, DriverContext driverContext) {
return internalState.toValuesBlock(selected, driverContext);
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
internalState.enableGroupIdTracking(seen);
}
@@ -105,7 +106,7 @@ public void close() {
}
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final BytesRef internalState;
private boolean seen;
@@ -121,7 +122,8 @@ public void add(BytesRef value) {
}
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = driverContext.blockFactory().newConstantBytesRefBlockWith(internalState, 1);
blocks[offset + 1] = driverContext.blockFactory().newConstantBooleanBlockWith(seen, 1);
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinBytesRefAggregator.java
index 830900702a371..677b38a9af3a7 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinBytesRefAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinBytesRefAggregator.java
@@ -17,7 +17,6 @@
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.operator.BreakingBytesRefBuilder;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
/**
@@ -71,7 +70,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(selected, driverContext);
}
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final BytesRefArrayState internalState;
private GroupingState(BigArrays bigArrays, CircuitBreaker breaker) {
@@ -90,7 +89,8 @@ public void combine(int groupId, GroupingState otherState, int otherGroupId) {
}
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
internalState.toIntermediate(blocks, offset, selected, driverContext);
}
@@ -98,7 +98,8 @@ Block toBlock(IntVector selected, DriverContext driverContext) {
return internalState.toValuesBlock(selected, driverContext);
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
internalState.enableGroupIdTracking(seen);
}
@@ -108,7 +109,7 @@ public void close() {
}
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final BreakingBytesRefBuilder internalState;
private boolean seen;
@@ -128,7 +129,8 @@ public void add(BytesRef value) {
}
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = driverContext.blockFactory().newConstantBytesRefBlockWith(internalState.bytesRefView(), 1);
blocks[offset + 1] = driverContext.blockFactory().newConstantBooleanBlockWith(seen, 1);
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinIpAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinIpAggregator.java
index 8313756851c1f..c4ee93db89cf8 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinIpAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinIpAggregator.java
@@ -15,7 +15,6 @@
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
@Aggregator({ @IntermediateState(name = "max", type = "BYTES_REF"), @IntermediateState(name = "seen", type = "BOOLEAN") })
@@ -67,7 +66,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(selected, driverContext);
}
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final BytesRef scratch = new BytesRef();
private final IpArrayState internalState;
@@ -87,7 +86,8 @@ public void combine(int groupId, GroupingState otherState, int otherGroupId) {
}
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
internalState.toIntermediate(blocks, offset, selected, driverContext);
}
@@ -95,7 +95,8 @@ Block toBlock(IntVector selected, DriverContext driverContext) {
return internalState.toValuesBlock(selected, driverContext);
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
internalState.enableGroupIdTracking(seen);
}
@@ -105,7 +106,7 @@ public void close() {
}
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final BytesRef internalState;
private boolean seen;
@@ -121,7 +122,8 @@ public void add(BytesRef value) {
}
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = driverContext.blockFactory().newConstantBytesRefBlockWith(internalState, 1);
blocks[offset + 1] = driverContext.blockFactory().newConstantBooleanBlockWith(seen, 1);
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/QuantileStates.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/QuantileStates.java
index 329e798dcb3f0..d5ea72ed23e5e 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/QuantileStates.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/QuantileStates.java
@@ -146,7 +146,8 @@ void add(int groupId, TDigestState other) {
}
}
- void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
// We always enable.
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java
index bff8903fd3bec..5b48498d83294 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java
@@ -204,7 +204,8 @@ public void close() {
Releasables.close(states);
}
- void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
// noop - we handle the null states inside `toIntermediate` and `evaluateFinal`
}
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBooleanAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBooleanAggregator.java
index 252436ad9634f..e19d3107172e3 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBooleanAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBooleanAggregator.java
@@ -17,7 +17,6 @@
import org.elasticsearch.compute.data.BooleanBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
/**
@@ -84,11 +83,12 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive
return state.toBlock(driverContext.blockFactory(), selected);
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private boolean seenFalse;
private boolean seenTrue;
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
@@ -113,14 +113,15 @@ Block toBlock(BlockFactory blockFactory) {
public void close() {}
}
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final BitArray values;
private GroupingState(BigArrays bigArrays) {
values = new BitArray(1, bigArrays);
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
@@ -155,7 +156,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) {
}
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
// we don't need to track which values have been seen because we don't do anything special for groups without values
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateAggregator.java.st
index 2581d3ebbf80b..a0b4ed8bd6337 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateAggregator.java.st
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateAggregator.java.st
@@ -338,7 +338,8 @@ public class Rate$Type$Aggregator {
}
}
- void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seenGroupIds) {
// noop - we handle the null states inside `toIntermediate` and `evaluateFinal`
}
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st
index 18d573eea4a4c..761b70791e946 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st
@@ -28,7 +28,6 @@ import org.elasticsearch.compute.data.$Type$Block;
$endif$
import org.elasticsearch.compute.data.sort.$Name$BucketedSort;
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.sort.SortOrder;
@@ -99,7 +98,7 @@ $endif$
return state.toBlock(driverContext.blockFactory(), selected);
}
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
private final $Name$BucketedSort sort;
private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
@@ -120,7 +119,8 @@ $endif$
sort.merge(groupId, other.sort, otherGroupId);
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
@@ -128,7 +128,8 @@ $endif$
return sort.toBlock(blockFactory, selected);
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
@@ -138,7 +139,7 @@ $endif$
}
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
private final GroupingState internalState;
private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
@@ -153,7 +154,8 @@ $endif$
internalState.merge(0, other, 0);
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st
index 1cef234b2238f..3006af595be1f 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st
@@ -35,7 +35,6 @@ $if(long)$
import org.elasticsearch.compute.data.LongBlock;
$endif$
import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.core.Releasable;
$if(BytesRef)$
import org.elasticsearch.core.Releasables;
@@ -155,7 +154,7 @@ $endif$
return state.toBlock(driverContext.blockFactory(), selected);
}
- public static class SingleState implements Releasable {
+ public static class SingleState implements AggregatorState {
$if(BytesRef)$
private final BytesRefHash values;
@@ -171,7 +170,8 @@ $else$
$endif$
}
- void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory());
}
@@ -228,7 +228,7 @@ $endif$
* an {@code O(n^2)} operation for collection to support a {@code O(1)}
* collector operation. But at least it's fairly simple.
*/
- public static class GroupingState implements Releasable {
+ public static class GroupingState implements GroupingAggregatorState {
$if(long||double)$
private final LongLongHash values;
@@ -263,7 +263,8 @@ $elseif(int||float)$
$endif$
}
- void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+ @Override
+ public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
blocks[offset] = toBlock(driverContext.blockFactory(), selected);
}
@@ -324,7 +325,8 @@ $endif$
}
}
- void enableGroupIdTracking(SeenGroupIds seen) {
+ @Override
+ public void enableGroupIdTracking(SeenGroupIds seen) {
// we figure out seen values from nulls on the values block
}
diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/CentroidPointAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/CentroidPointAggregator.java
index 47d927fda91b5..c3b07d069cf11 100644
--- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/CentroidPointAggregator.java
+++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/CentroidPointAggregator.java
@@ -260,7 +260,8 @@ boolean hasValue(int index) {
}
/** Needed for generated code that does null tracking, which we do not need because we use count */
- final void enableGroupIdTracking(SeenGroupIds ignore) {}
+ @Override
+ public final void enableGroupIdTracking(SeenGroupIds ignore) {}
private void ensureCapacity(int groupId) {
if (groupId >= xValues.size()) {