diff --git a/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java b/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java index 444dbcc1b9e58..8096153459003 100644 --- a/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java +++ b/x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Aggregator.java @@ -37,11 +37,6 @@ * are ever collected. *

*

- * The generation code will also look for a method called {@code combineValueCount} - * which is called once per received block with a count of values. NOTE: We may - * not need this after we convert AVG into a composite operation. - *

- *

* The generation code also looks for the optional methods {@code combineIntermediate} * and {@code evaluateFinal} which are used to combine intermediate states and * produce the final output. If the first is missing then the generated code will @@ -63,4 +58,8 @@ */ Class[] warnExceptions() default {}; + /** + * If {@code true} then the @timestamp LongVector will be appended to the input blocks of the aggregation function. + */ + boolean includeTimestamps() default false; } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java index 46881bf337c89..d775a46109214 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java @@ -17,12 +17,12 @@ import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.gen.Methods.TypeMatcher; import java.util.Arrays; import java.util.List; import java.util.Locale; import java.util.Objects; -import java.util.regex.Pattern; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -34,27 +34,24 @@ import javax.lang.model.util.Elements; import static java.util.stream.Collectors.joining; -import static org.elasticsearch.compute.gen.Methods.findMethod; -import static org.elasticsearch.compute.gen.Methods.findRequiredMethod; +import static org.elasticsearch.compute.gen.Methods.requireAnyArgs; +import static org.elasticsearch.compute.gen.Methods.requireAnyType; +import static org.elasticsearch.compute.gen.Methods.requireArgs; +import static org.elasticsearch.compute.gen.Methods.requireName; +import static org.elasticsearch.compute.gen.Methods.requirePrimitiveOrImplements; +import static org.elasticsearch.compute.gen.Methods.requireStaticMethod; +import static org.elasticsearch.compute.gen.Methods.requireType; +import static org.elasticsearch.compute.gen.Methods.requireVoidType; import static org.elasticsearch.compute.gen.Methods.vectorAccessorName; import static org.elasticsearch.compute.gen.Types.AGGREGATOR_FUNCTION; import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS; import static org.elasticsearch.compute.gen.Types.BLOCK; import static org.elasticsearch.compute.gen.Types.BLOCK_ARRAY; -import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK; import static org.elasticsearch.compute.gen.Types.BOOLEAN_VECTOR; import static org.elasticsearch.compute.gen.Types.BYTES_REF; -import static org.elasticsearch.compute.gen.Types.BYTES_REF_BLOCK; -import static org.elasticsearch.compute.gen.Types.BYTES_REF_VECTOR; -import static org.elasticsearch.compute.gen.Types.DOUBLE_BLOCK; -import static org.elasticsearch.compute.gen.Types.DOUBLE_VECTOR; import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT; import static org.elasticsearch.compute.gen.Types.ELEMENT_TYPE; -import static org.elasticsearch.compute.gen.Types.FLOAT_BLOCK; -import static org.elasticsearch.compute.gen.Types.FLOAT_VECTOR; import static org.elasticsearch.compute.gen.Types.INTERMEDIATE_STATE_DESC; -import static org.elasticsearch.compute.gen.Types.INT_BLOCK; -import static org.elasticsearch.compute.gen.Types.INT_VECTOR; import static org.elasticsearch.compute.gen.Types.LIST_AGG_FUNC_DESC; import static org.elasticsearch.compute.gen.Types.LIST_INTEGER; import static org.elasticsearch.compute.gen.Types.LONG_BLOCK; @@ -78,46 +75,41 @@ public class AggregatorImplementer { private final List warnExceptions; private final ExecutableElement init; private final ExecutableElement combine; - private final ExecutableElement combineValueCount; - private final ExecutableElement combineIntermediate; - private final ExecutableElement evaluateFinal; + private final List createParameters; private final ClassName implementation; - private final TypeName stateType; - private final boolean stateTypeHasSeen; - private final boolean stateTypeHasFailed; - private final boolean valuesIsBytesRef; - private final boolean valuesIsArray; private final List intermediateState; - private final List createParameters; + private final boolean includeTimestampVector; + + private final AggregationState aggState; + private final AggregationParameter aggParam; public AggregatorImplementer( Elements elements, TypeElement declarationType, IntermediateState[] interStateAnno, - List warnExceptions + List warnExceptions, + boolean includeTimestampVector ) { this.declarationType = declarationType; this.warnExceptions = warnExceptions; - this.init = findRequiredMethod(declarationType, new String[] { "init", "initSingle" }, e -> true); - this.stateType = choseStateType(); - this.stateTypeHasSeen = elements.getAllMembers(elements.getTypeElement(stateType.toString())) - .stream() - .anyMatch(e -> e.toString().equals("seen()")); - this.stateTypeHasFailed = elements.getAllMembers(elements.getTypeElement(stateType.toString())) - .stream() - .anyMatch(e -> e.toString().equals("failed()")); + this.init = requireStaticMethod( + declarationType, + requirePrimitiveOrImplements(elements, Types.AGGREGATOR_STATE), + requireName("init", "initSingle"), + requireAnyArgs("") + ); + this.aggState = AggregationState.create(elements, init.getReturnType(), warnExceptions.isEmpty() == false, false); + + this.combine = requireStaticMethod( + declarationType, + aggState.declaredType().isPrimitive() ? requireType(aggState.declaredType()) : requireVoidType(), + requireName("combine"), + combineArgs(aggState, includeTimestampVector) + ); + // TODO support multiple parameters + this.aggParam = AggregationParameter.create(combine.getParameters().getLast().asType()); - this.combine = findRequiredMethod(declarationType, new String[] { "combine" }, e -> { - if (e.getParameters().size() == 0) { - return false; - } - TypeName firstParamType = TypeName.get(e.getParameters().get(0).asType()); - return firstParamType.isPrimitive() || firstParamType.toString().equals(stateType.toString()); - }); - this.combineValueCount = findMethod(declarationType, "combineValueCount"); - this.combineIntermediate = findMethod(declarationType, "combineIntermediate"); - this.evaluateFinal = findMethod(declarationType, "evaluateFinal"); this.createParameters = init.getParameters() .stream() .map(Parameter::from) @@ -128,9 +120,20 @@ public AggregatorImplementer( elements.getPackageOf(declarationType).toString(), (declarationType.getSimpleName() + "AggregatorFunction").replace("AggregatorAggregator", "Aggregator") ); - this.valuesIsBytesRef = BYTES_REF.equals(valueTypeName()); - this.valuesIsArray = TypeKind.ARRAY.equals(valueTypeKind()); - intermediateState = Arrays.stream(interStateAnno).map(IntermediateStateDesc::newIntermediateStateDesc).toList(); + this.intermediateState = Arrays.stream(interStateAnno).map(IntermediateStateDesc::newIntermediateStateDesc).toList(); + this.includeTimestampVector = includeTimestampVector; + } + + private static Methods.ArgumentMatcher combineArgs(AggregationState aggState, boolean includeTimestampVector) { + if (includeTimestampVector) { + return requireArgs( + requireType(aggState.declaredType()), + requireType(TypeName.LONG), // @timestamp + requireAnyType("") + ); + } else { + return requireArgs(requireType(aggState.declaredType()), requireAnyType("")); + } } ClassName implementation() { @@ -141,68 +144,8 @@ List createParameters() { return createParameters; } - private TypeName choseStateType() { - TypeName initReturn = TypeName.get(init.getReturnType()); - if (false == initReturn.isPrimitive()) { - return initReturn; - } - String simpleName = firstUpper(initReturn.toString()); - if (warnExceptions.isEmpty()) { - return ClassName.get("org.elasticsearch.compute.aggregation", simpleName + "State"); - } - return ClassName.get("org.elasticsearch.compute.aggregation", simpleName + "FallibleState"); - } - - static String valueType(ExecutableElement init, ExecutableElement combine) { - if (combine != null) { - // If there's an explicit combine function it's final parameter is the type of the value. - return combine.getParameters().get(combine.getParameters().size() - 1).asType().toString(); - } - String initReturn = init.getReturnType().toString(); - switch (initReturn) { - case "double": - return "double"; - case "float": - return "float"; - case "long": - return "long"; - case "int": - return "int"; - case "boolean": - return "boolean"; - default: - throw new IllegalArgumentException("unknown primitive type for " + initReturn); - } - } - - static ClassName valueBlockType(ExecutableElement init, ExecutableElement combine) { - return switch (valueType(init, combine)) { - case "boolean" -> BOOLEAN_BLOCK; - case "double" -> DOUBLE_BLOCK; - case "float" -> FLOAT_BLOCK; - case "long" -> LONG_BLOCK; - case "int", "int[]" -> INT_BLOCK; - case "org.apache.lucene.util.BytesRef" -> BYTES_REF_BLOCK; - default -> throw new IllegalArgumentException("unknown block type for " + valueType(init, combine)); - }; - } - - static ClassName valueVectorType(ExecutableElement init, ExecutableElement combine) { - return switch (valueType(init, combine)) { - case "boolean" -> BOOLEAN_VECTOR; - case "double" -> DOUBLE_VECTOR; - case "float" -> FLOAT_VECTOR; - case "long" -> LONG_VECTOR; - case "int", "int[]" -> INT_VECTOR; - case "org.apache.lucene.util.BytesRef" -> BYTES_REF_VECTOR; - default -> throw new IllegalArgumentException("unknown vector type for " + valueType(init, combine)); - }; - } - - public static String firstUpper(String s) { - String head = s.toString().substring(0, 1).toUpperCase(Locale.ROOT); - String tail = s.toString().substring(1); - return head + tail; + public static String capitalize(String s) { + return Character.toUpperCase(s.charAt(0)) + s.substring(1); } public JavaFile sourceFile() { @@ -232,7 +175,7 @@ private TypeSpec type() { } builder.addField(DRIVER_CONTEXT, "driverContext", Modifier.PRIVATE, Modifier.FINAL); - builder.addField(stateType, "state", Modifier.PRIVATE, Modifier.FINAL); + builder.addField(aggState.type, "state", Modifier.PRIVATE, Modifier.FINAL); builder.addField(LIST_INTEGER, "channels", Modifier.PRIVATE, Modifier.FINAL); for (Parameter p : createParameters) { @@ -292,10 +235,10 @@ private CodeBlock callInit() { .map(p -> TypeName.get(p.asType()).equals(BIG_ARRAYS) ? "driverContext.bigArrays()" : p.getSimpleName().toString()) .collect(joining(", ")); CodeBlock.Builder builder = CodeBlock.builder(); - if (init.getReturnType().toString().equals(stateType.toString())) { - builder.add("$T.$L($L)", declarationType, init.getSimpleName(), initParametersCall); + if (aggState.declaredType().isPrimitive()) { + builder.add("new $T($T.$L($L))", aggState.type(), declarationType, init.getSimpleName(), initParametersCall); } else { - builder.add("new $T($T.$L($L))", stateType, declarationType, init.getSimpleName(), initParametersCall); + builder.add("$T.$L($L)", declarationType, init.getSimpleName(), initParametersCall); } return builder.build(); } @@ -320,7 +263,7 @@ private MethodSpec ctor() { } builder.addParameter(DRIVER_CONTEXT, "driverContext"); builder.addParameter(LIST_INTEGER, "channels"); - builder.addParameter(stateType, "state"); + builder.addParameter(aggState.type, "state"); if (warnExceptions.isEmpty() == false) { builder.addStatement("this.warnings = warnings"); @@ -352,7 +295,7 @@ private MethodSpec intermediateBlockCount() { private MethodSpec addRawInput() { MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawInput"); builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC).addParameter(PAGE, "page").addParameter(BOOLEAN_VECTOR, "mask"); - if (stateTypeHasFailed) { + if (aggState.hasFailed()) { builder.beginControlFlow("if (state.failed())"); builder.addStatement("return"); builder.endControlFlow(); @@ -366,43 +309,62 @@ private MethodSpec addRawInput() { builder.beginControlFlow("if (mask.allTrue())"); { builder.addComment("No masking"); - builder.addStatement("$T block = page.getBlock(channels.get(0))", valueBlockType(init, combine)); - builder.addStatement("$T vector = block.asVector()", valueVectorType(init, combine)); + builder.addStatement("$T block = page.getBlock(channels.get(0))", blockType(aggParam.type())); + builder.addStatement("$T vector = block.asVector()", vectorType(aggParam.type())); + if (includeTimestampVector) { + builder.addStatement("$T timestampsBlock = page.getBlock(channels.get(1))", LONG_BLOCK); + builder.addStatement("$T timestampsVector = timestampsBlock.asVector()", LONG_VECTOR); + + builder.beginControlFlow("if (timestampsVector == null) "); + builder.addStatement("throw new IllegalStateException($S)", "expected @timestamp vector; but got a block"); + builder.endControlFlow(); + } builder.beginControlFlow("if (vector != null)"); - builder.addStatement("addRawVector(vector)"); + builder.addStatement(includeTimestampVector ? "addRawVector(vector, timestampsVector)" : "addRawVector(vector)"); builder.nextControlFlow("else"); - builder.addStatement("addRawBlock(block)"); + builder.addStatement(includeTimestampVector ? "addRawBlock(block, timestampsVector)" : "addRawBlock(block)"); builder.endControlFlow(); builder.addStatement("return"); } builder.endControlFlow(); builder.addComment("Some positions masked away, others kept"); - builder.addStatement("$T block = page.getBlock(channels.get(0))", valueBlockType(init, combine)); - builder.addStatement("$T vector = block.asVector()", valueVectorType(init, combine)); + builder.addStatement("$T block = page.getBlock(channels.get(0))", blockType(aggParam.type())); + builder.addStatement("$T vector = block.asVector()", vectorType(aggParam.type())); + if (includeTimestampVector) { + builder.addStatement("$T timestampsBlock = page.getBlock(channels.get(1))", LONG_BLOCK); + builder.addStatement("$T timestampsVector = timestampsBlock.asVector()", LONG_VECTOR); + + builder.beginControlFlow("if (timestampsVector == null) "); + builder.addStatement("throw new IllegalStateException($S)", "expected @timestamp vector; but got a block"); + builder.endControlFlow(); + } builder.beginControlFlow("if (vector != null)"); - builder.addStatement("addRawVector(vector, mask)"); + builder.addStatement(includeTimestampVector ? "addRawVector(vector, timestampsVector, mask)" : "addRawVector(vector, mask)"); builder.nextControlFlow("else"); - builder.addStatement("addRawBlock(block, mask)"); + builder.addStatement(includeTimestampVector ? "addRawBlock(block, timestampsVector, mask)" : "addRawBlock(block, mask)"); builder.endControlFlow(); return builder.build(); } private MethodSpec addRawVector(boolean masked) { MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawVector"); - builder.addModifiers(Modifier.PRIVATE).addParameter(valueVectorType(init, combine), "vector"); + builder.addModifiers(Modifier.PRIVATE).addParameter(vectorType(aggParam.type()), "vector"); + if (includeTimestampVector) { + builder.addParameter(LONG_VECTOR, "timestamps"); + } if (masked) { builder.addParameter(BOOLEAN_VECTOR, "mask"); } - if (valuesIsArray) { + if (aggParam.isArray()) { builder.addComment("This type does not support vectors because all values are multi-valued"); return builder.build(); } - if (stateTypeHasSeen) { + if (aggState.hasSeen()) { builder.addStatement("state.seen(true)"); } - if (valuesIsBytesRef) { + if (aggParam.isBytesRef()) { // Add bytes_ref scratch var that will be used for bytes_ref blocks/vectors builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF); } @@ -415,20 +377,20 @@ private MethodSpec addRawVector(boolean masked) { combineRawInput(builder, "vector"); } builder.endControlFlow(); - if (combineValueCount != null) { - builder.addStatement("$T.combineValueCount(state, vector.getPositionCount())", declarationType); - } return builder.build(); } private MethodSpec addRawBlock(boolean masked) { MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawBlock"); - builder.addModifiers(Modifier.PRIVATE).addParameter(valueBlockType(init, combine), "block"); + builder.addModifiers(Modifier.PRIVATE).addParameter(blockType(aggParam.type()), "block"); + if (includeTimestampVector) { + builder.addParameter(LONG_VECTOR, "timestamps"); + } if (masked) { builder.addParameter(BOOLEAN_VECTOR, "mask"); } - if (valuesIsBytesRef) { + if (aggParam.isBytesRef()) { // Add bytes_ref scratch var that will only be used for bytes_ref blocks/vectors builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF); } @@ -440,16 +402,16 @@ private MethodSpec addRawBlock(boolean masked) { builder.beginControlFlow("if (block.isNull(p))"); builder.addStatement("continue"); builder.endControlFlow(); - if (stateTypeHasSeen) { + if (aggState.hasSeen()) { builder.addStatement("state.seen(true)"); } builder.addStatement("int start = block.getFirstValueIndex(p)"); builder.addStatement("int end = start + block.getValueCount(p)"); - if (valuesIsArray) { - String arrayType = valueTypeString(); + if (aggParam.isArray()) { + String arrayType = aggParam.type().toString().replace("[]", ""); builder.addStatement("$L[] valuesArray = new $L[end - start]", arrayType, arrayType); builder.beginControlFlow("for (int i = start; i < end; i++)"); - builder.addStatement("valuesArray[i-start] = $L.get$L(i)", "block", firstUpper(arrayType)); + builder.addStatement("valuesArray[i-start] = $L.get$L(i)", "block", capitalize(arrayType)); builder.endControlFlow(); combineRawInputForArray(builder, "valuesArray"); } else { @@ -459,16 +421,13 @@ private MethodSpec addRawBlock(boolean masked) { } } builder.endControlFlow(); - if (combineValueCount != null) { - builder.addStatement("$T.combineValueCount(state, block.getTotalValueCount())", declarationType); - } return builder.build(); } private void combineRawInput(MethodSpec.Builder builder, String blockVariable) { TypeName returnType = TypeName.get(combine.getReturnType()); warningsBlock(builder, () -> { - if (valuesIsBytesRef) { + if (aggParam.isBytesRef()) { combineRawInputForBytesRef(builder, blockVariable); } else if (returnType.isPrimitive()) { combineRawInputForPrimitive(returnType, builder, blockVariable); @@ -480,33 +439,57 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable) { }); } - private void combineRawInputForPrimitive(TypeName returnType, MethodSpec.Builder builder, String blockVariable) { - builder.addStatement( - "state.$TValue($T.combine(state.$TValue(), $L.get$L(i)))", - returnType, - declarationType, - returnType, - blockVariable, - firstUpper(combine.getParameters().get(1).asType().toString()) - ); + private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable) { + // scratch is a BytesRef var that must have been defined before the iteration starts + if (includeTimestampVector) { + builder.addStatement("$T.combine(state, timestamps.getLong(i), $L.getBytesRef(i, scratch))", declarationType, blockVariable); + } else { + builder.addStatement("$T.combine(state, $L.getBytesRef(i, scratch))", declarationType, blockVariable); + } } - private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) { - warningsBlock(builder, () -> builder.addStatement("$T.combine(state, $L)", declarationType, arrayVariable)); + private void combineRawInputForPrimitive(TypeName returnType, MethodSpec.Builder builder, String blockVariable) { + if (includeTimestampVector) { + builder.addStatement( + "state.$TValue($T.combine(state.$TValue(), timestamps.getLong(i), $L.get$L(i)))", + returnType, + declarationType, + returnType, + blockVariable, + capitalize(combine.getParameters().get(1).asType().toString()) + ); + } else { + builder.addStatement( + "state.$TValue($T.combine(state.$TValue(), $L.get$L(i)))", + returnType, + declarationType, + returnType, + blockVariable, + capitalize(combine.getParameters().get(1).asType().toString()) + ); + } } private void combineRawInputForVoid(MethodSpec.Builder builder, String blockVariable) { - builder.addStatement( - "$T.combine(state, $L.get$L(i))", - declarationType, - blockVariable, - firstUpper(combine.getParameters().get(1).asType().toString()) - ); + if (includeTimestampVector) { + builder.addStatement( + "$T.combine(state, timestamps.getLong(i), $L.get$L(i))", + declarationType, + blockVariable, + capitalize(combine.getParameters().get(1).asType().toString()) + ); + } else { + builder.addStatement( + "$T.combine(state, $L.get$L(i))", + declarationType, + blockVariable, + capitalize(combine.getParameters().get(1).asType().toString()) + ); + } } - private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable) { - // scratch is a BytesRef var that must have been defined before the iteration starts - builder.addStatement("$T.combine(state, $L.getBytesRef(i, scratch))", declarationType, blockVariable); + private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) { + warningsBlock(builder, () -> builder.addStatement("$T.combine(state, $L)", declarationType, arrayVariable)); } private void warningsBlock(MethodSpec.Builder builder, Runnable block) { @@ -534,12 +517,7 @@ private MethodSpec addIntermediateInput() { interState.assignToVariable(builder, i); builder.addStatement("assert $L.getPositionCount() == 1", interState.name()); } - if (combineIntermediate != null) { - if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) { - builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF); - } - builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType); - } else if (hasPrimitiveState()) { + if (aggState.declaredType().isPrimitive()) { if (warnExceptions.isEmpty()) { assert intermediateState.size() == 2; assert intermediateState.get(1).name().equals("seen"); @@ -557,14 +535,36 @@ private MethodSpec addIntermediateInput() { } warningsBlock(builder, () -> { + var primitiveStateMethod = switch (aggState.declaredType().toString()) { + case "boolean" -> "booleanValue"; + case "int" -> "intValue"; + case "long" -> "longValue"; + case "double" -> "doubleValue"; + case "float" -> "floatValue"; + default -> throw new IllegalArgumentException("Unexpected primitive type: [" + aggState.declaredType() + "]"); + }; var state = intermediateState.get(0); var s = "state.$L($T.combine(state.$L(), " + state.name() + "." + vectorAccessorName(state.elementType()) + "(0)))"; - builder.addStatement(s, primitiveStateMethod(), declarationType, primitiveStateMethod()); + builder.addStatement(s, primitiveStateMethod, declarationType, primitiveStateMethod); builder.addStatement("state.seen(true)"); }); builder.endControlFlow(); } else { - throw new IllegalArgumentException("Don't know how to combine intermediate input. Define combineIntermediate"); + requireStaticMethod( + declarationType, + requireVoidType(), + requireName("combineIntermediate"), + requireArgs( + Stream.concat( + Stream.of(aggState.declaredType()), // aggState + intermediateState.stream().map(IntermediateStateDesc::combineArgType) // intermediate state + ).map(Methods::requireType).toArray(TypeMatcher[]::new) + ) + ); + if (intermediateState.stream().map(IntermediateStateDesc::elementType).anyMatch(n -> n.equals("BYTES_REF"))) { + builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF); + } + builder.addStatement("$T.combineIntermediate(state, " + intermediateStateRowAccess() + ")", declarationType); } return builder.build(); } @@ -573,25 +573,6 @@ String intermediateStateRowAccess() { return intermediateState.stream().map(desc -> desc.access("0")).collect(joining(", ")); } - private String primitiveStateMethod() { - switch (stateType.toString()) { - case "org.elasticsearch.compute.aggregation.BooleanState", "org.elasticsearch.compute.aggregation.BooleanFallibleState": - return "booleanValue"; - case "org.elasticsearch.compute.aggregation.IntState", "org.elasticsearch.compute.aggregation.IntFallibleState": - return "intValue"; - case "org.elasticsearch.compute.aggregation.LongState", "org.elasticsearch.compute.aggregation.LongFallibleState": - return "longValue"; - case "org.elasticsearch.compute.aggregation.DoubleState", "org.elasticsearch.compute.aggregation.DoubleFallibleState": - return "doubleValue"; - case "org.elasticsearch.compute.aggregation.FloatState", "org.elasticsearch.compute.aggregation.FloatFallibleState": - return "floatValue"; - default: - throw new IllegalArgumentException( - "don't know how to fetch primitive values from " + stateType + ". define combineIntermediate." - ); - } - } - private MethodSpec evaluateIntermediate() { MethodSpec.Builder builder = MethodSpec.methodBuilder("evaluateIntermediate"); builder.addAnnotation(Override.class) @@ -610,45 +591,39 @@ private MethodSpec evaluateFinal() { .addParameter(BLOCK_ARRAY, "blocks") .addParameter(TypeName.INT, "offset") .addParameter(DRIVER_CONTEXT, "driverContext"); - if (stateTypeHasSeen || stateTypeHasFailed) { - var condition = Stream.of(stateTypeHasSeen ? "state.seen() == false" : null, stateTypeHasFailed ? "state.failed()" : null) - .filter(Objects::nonNull) - .collect(joining(" || ")); - builder.beginControlFlow("if ($L)", condition); + if (aggState.hasSeen() || aggState.hasFailed()) { + builder.beginControlFlow( + "if ($L)", + Stream.concat( + Stream.of("state.seen() == false").filter(c -> aggState.hasSeen()), + Stream.of("state.failed()").filter(c -> aggState.hasFailed()) + ).collect(joining(" || ")) + ); builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantNullBlock(1)", BLOCK); builder.addStatement("return"); builder.endControlFlow(); } - if (evaluateFinal == null) { - primitiveStateToResult(builder); + if (aggState.declaredType().isPrimitive()) { + builder.addStatement(switch (aggState.declaredType().toString()) { + case "boolean" -> "blocks[offset] = driverContext.blockFactory().newConstantBooleanBlockWith(state.booleanValue(), 1)"; + case "int" -> "blocks[offset] = driverContext.blockFactory().newConstantIntBlockWith(state.intValue(), 1)"; + case "long" -> "blocks[offset] = driverContext.blockFactory().newConstantLongBlockWith(state.longValue(), 1)"; + case "double" -> "blocks[offset] = driverContext.blockFactory().newConstantDoubleBlockWith(state.doubleValue(), 1)"; + case "float" -> "blocks[offset] = driverContext.blockFactory().newConstantFloatBlockWith(state.floatValue(), 1)"; + default -> throw new IllegalArgumentException("Unexpected primitive type: [" + aggState.declaredType() + "]"); + }); } else { + requireStaticMethod( + declarationType, + requireType(BLOCK), + requireName("evaluateFinal"), + requireArgs(requireType(aggState.declaredType()), requireType(DRIVER_CONTEXT)) + ); builder.addStatement("blocks[offset] = $T.evaluateFinal(state, driverContext)", declarationType); } return builder.build(); } - private void primitiveStateToResult(MethodSpec.Builder builder) { - switch (stateType.toString()) { - case "org.elasticsearch.compute.aggregation.BooleanState", "org.elasticsearch.compute.aggregation.BooleanFallibleState": - builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantBooleanBlockWith(state.booleanValue(), 1)"); - return; - case "org.elasticsearch.compute.aggregation.IntState", "org.elasticsearch.compute.aggregation.IntFallibleState": - builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantIntBlockWith(state.intValue(), 1)"); - return; - case "org.elasticsearch.compute.aggregation.LongState", "org.elasticsearch.compute.aggregation.LongFallibleState": - builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantLongBlockWith(state.longValue(), 1)"); - return; - case "org.elasticsearch.compute.aggregation.DoubleState", "org.elasticsearch.compute.aggregation.DoubleFallibleState": - builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantDoubleBlockWith(state.doubleValue(), 1)"); - return; - case "org.elasticsearch.compute.aggregation.FloatState", "org.elasticsearch.compute.aggregation.FloatFallibleState": - builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantFloatBlockWith(state.floatValue(), 1)"); - return; - default: - throw new IllegalArgumentException("don't know how to convert state to result: " + stateType); - } - } - private MethodSpec toStringMethod() { MethodSpec.Builder builder = MethodSpec.methodBuilder("toString"); builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC).returns(String.class); @@ -667,14 +642,6 @@ private MethodSpec close() { return builder.build(); } - private static final Pattern PRIMITIVE_STATE_PATTERN = Pattern.compile( - "org.elasticsearch.compute.aggregation.(Boolean|Int|Long|Double|Float)(Fallible)?State" - ); - - private boolean hasPrimitiveState() { - return PRIMITIVE_STATE_PATTERN.matcher(stateType.toString()).matches(); - } - record IntermediateStateDesc(String name, String elementType, boolean block) { static IntermediateStateDesc newIntermediateStateDesc(IntermediateState state) { String type = state.type(); @@ -711,22 +678,57 @@ public void assignToVariable(MethodSpec.Builder builder, int offset) { builder.addStatement("$T $L = (($T) $L).asVector()", vectorType(elementType), name, blockType, name + "Uncast"); } } - } - private TypeMirror valueTypeMirror() { - return combine.getParameters().get(combine.getParameters().size() - 1).asType(); + public TypeName combineArgType() { + var type = Types.fromString(elementType); + return block ? blockType(type) : type; + } } - private TypeName valueTypeName() { - return TypeName.get(valueTypeMirror()); + /** + * This represents the type returned by init method used to keep aggregation state + * @param declaredType declared state type as returned by init method + * @param type actual type used (we have some predefined state types for primitive values) + */ + public record AggregationState(TypeName declaredType, TypeName type, boolean hasSeen, boolean hasFailed) { + + public static AggregationState create(Elements elements, TypeMirror mirror, boolean hasFailures, boolean isArray) { + var declaredType = TypeName.get(mirror); + var stateType = declaredType.isPrimitive() + ? ClassName.get("org.elasticsearch.compute.aggregation", primitiveStateStoreClassname(declaredType, hasFailures, isArray)) + : declaredType; + return new AggregationState( + declaredType, + stateType, + hasMethod(elements, stateType, "seen()"), + hasMethod(elements, stateType, "failed()") + ); + } + + private static String primitiveStateStoreClassname(TypeName declaredType, boolean hasFailures, boolean isArray) { + var name = capitalize(declaredType.toString()); + if (hasFailures) { + name += "Fallible"; + } + if (isArray) { + name += "Array"; + } + return name + "State"; + } } - private TypeKind valueTypeKind() { - return valueTypeMirror().getKind(); + public record AggregationParameter(TypeName type, boolean isArray) { + + public static AggregationParameter create(TypeMirror mirror) { + return new AggregationParameter(TypeName.get(mirror), Objects.equals(mirror.getKind(), TypeKind.ARRAY)); + } + + public boolean isBytesRef() { + return Objects.equals(type, BYTES_REF); + } } - private String valueTypeString() { - String valueTypeString = TypeName.get(valueTypeMirror()).toString(); - return valuesIsArray ? valueTypeString.substring(0, valueTypeString.length() - 2) : valueTypeString; + private static boolean hasMethod(Elements elements, TypeName type, String name) { + return elements.getAllMembers(elements.getTypeElement(type.toString())).stream().anyMatch(e -> e.toString().equals(name)); } } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java index 863db86eb934a..3ad2343ad1658 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorProcessor.java @@ -87,7 +87,13 @@ public boolean process(Set set, RoundEnvironment roundEnv ); if (aggClass.getAnnotation(Aggregator.class) != null) { IntermediateState[] intermediateState = aggClass.getAnnotation(Aggregator.class).value(); - implementer = new AggregatorImplementer(env.getElementUtils(), aggClass, intermediateState, warnExceptionsTypes); + implementer = new AggregatorImplementer( + env.getElementUtils(), + aggClass, + intermediateState, + warnExceptionsTypes, + aggClass.getAnnotation(Aggregator.class).includeTimestamps() + ); write(aggClass, "aggregator", implementer.sourceFile(), env); } GroupingAggregatorImplementer groupingAggregatorImplementer = null; @@ -96,13 +102,12 @@ public boolean process(Set set, RoundEnvironment roundEnv if (intermediateState.length == 0 && aggClass.getAnnotation(Aggregator.class) != null) { intermediateState = aggClass.getAnnotation(Aggregator.class).value(); } - boolean includeTimestamps = aggClass.getAnnotation(GroupingAggregator.class).includeTimestamps(); groupingAggregatorImplementer = new GroupingAggregatorImplementer( env.getElementUtils(), aggClass, intermediateState, warnExceptionsTypes, - includeTimestamps + aggClass.getAnnotation(GroupingAggregator.class).includeTimestamps() ); write(aggClass, "grouping aggregator", groupingAggregatorImplementer.sourceFile(), env); } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index 8224c73936b90..d2b6a0e011687 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -17,28 +17,35 @@ import org.elasticsearch.compute.ann.Aggregator; import org.elasticsearch.compute.ann.IntermediateState; +import org.elasticsearch.compute.gen.AggregatorImplementer.AggregationParameter; +import org.elasticsearch.compute.gen.AggregatorImplementer.AggregationState; import java.util.Arrays; import java.util.List; import java.util.function.Consumer; -import java.util.regex.Pattern; +import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.lang.model.element.ExecutableElement; import javax.lang.model.element.Modifier; import javax.lang.model.element.TypeElement; -import javax.lang.model.type.TypeKind; import javax.lang.model.type.TypeMirror; import javax.lang.model.util.Elements; import static java.util.stream.Collectors.joining; -import static org.elasticsearch.compute.gen.AggregatorImplementer.firstUpper; -import static org.elasticsearch.compute.gen.AggregatorImplementer.valueBlockType; -import static org.elasticsearch.compute.gen.AggregatorImplementer.valueVectorType; -import static org.elasticsearch.compute.gen.Methods.findMethod; -import static org.elasticsearch.compute.gen.Methods.findRequiredMethod; +import static org.elasticsearch.compute.gen.AggregatorImplementer.capitalize; +import static org.elasticsearch.compute.gen.Methods.requireAnyArgs; +import static org.elasticsearch.compute.gen.Methods.requireAnyType; +import static org.elasticsearch.compute.gen.Methods.requireArgs; +import static org.elasticsearch.compute.gen.Methods.requireName; +import static org.elasticsearch.compute.gen.Methods.requirePrimitiveOrImplements; +import static org.elasticsearch.compute.gen.Methods.requireStaticMethod; +import static org.elasticsearch.compute.gen.Methods.requireType; +import static org.elasticsearch.compute.gen.Methods.requireVoidType; import static org.elasticsearch.compute.gen.Methods.vectorAccessorName; import static org.elasticsearch.compute.gen.Types.BIG_ARRAYS; +import static org.elasticsearch.compute.gen.Types.BLOCK; import static org.elasticsearch.compute.gen.Types.BLOCK_ARRAY; import static org.elasticsearch.compute.gen.Types.BYTES_REF; import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT; @@ -55,6 +62,8 @@ import static org.elasticsearch.compute.gen.Types.PAGE; import static org.elasticsearch.compute.gen.Types.SEEN_GROUP_IDS; import static org.elasticsearch.compute.gen.Types.WARNINGS; +import static org.elasticsearch.compute.gen.Types.blockType; +import static org.elasticsearch.compute.gen.Types.vectorType; /** * Implements "GroupingAggregationFunction" from a class containing static methods @@ -70,17 +79,14 @@ public class GroupingAggregatorImplementer { private final List warnExceptions; private final ExecutableElement init; private final ExecutableElement combine; - private final ExecutableElement combineStates; - private final ExecutableElement evaluateFinal; - private final ExecutableElement combineIntermediate; - private final TypeName stateType; - private final boolean valuesIsBytesRef; - private final boolean valuesIsArray; private final List createParameters; private final ClassName implementation; private final List intermediateState; private final boolean includeTimestampVector; + private final AggregationState aggState; + private final AggregationParameter aggParam; + public GroupingAggregatorImplementer( Elements elements, TypeElement declarationType, @@ -91,21 +97,23 @@ public GroupingAggregatorImplementer( this.declarationType = declarationType; this.warnExceptions = warnExceptions; - this.init = findRequiredMethod(declarationType, new String[] { "init", "initGrouping" }, e -> true); - this.stateType = choseStateType(); + this.init = requireStaticMethod( + declarationType, + requirePrimitiveOrImplements(elements, Types.GROUPING_AGGREGATOR_STATE), + requireName("init", "initGrouping"), + requireAnyArgs("") + ); + this.aggState = AggregationState.create(elements, init.getReturnType(), warnExceptions.isEmpty() == false, true); + + this.combine = requireStaticMethod( + declarationType, + aggState.declaredType().isPrimitive() ? requireType(aggState.declaredType()) : requireVoidType(), + requireName("combine"), + combineArgs(aggState, includeTimestampVector) + ); + // TODO support multiple parameters + this.aggParam = AggregationParameter.create(combine.getParameters().getLast().asType()); - this.combine = findRequiredMethod(declarationType, new String[] { "combine" }, e -> { - if (e.getParameters().size() == 0) { - return false; - } - TypeName firstParamType = TypeName.get(e.getParameters().get(0).asType()); - return firstParamType.isPrimitive() || firstParamType.toString().equals(stateType.toString()); - }); - this.combineStates = findMethod(declarationType, "combineStates"); - this.combineIntermediate = findMethod(declarationType, "combineIntermediate"); - this.evaluateFinal = findMethod(declarationType, "evaluateFinal"); - this.valuesIsBytesRef = BYTES_REF.equals(valueTypeName()); - this.valuesIsArray = TypeKind.ARRAY.equals(valueTypeKind()); this.createParameters = init.getParameters() .stream() .map(Parameter::from) @@ -117,12 +125,31 @@ public GroupingAggregatorImplementer( (declarationType.getSimpleName() + "GroupingAggregatorFunction").replace("AggregatorGroupingAggregator", "GroupingAggregator") ); - intermediateState = Arrays.stream(interStateAnno) + this.intermediateState = Arrays.stream(interStateAnno) .map(AggregatorImplementer.IntermediateStateDesc::newIntermediateStateDesc) .toList(); this.includeTimestampVector = includeTimestampVector; } + private static Methods.ArgumentMatcher combineArgs(AggregationState aggState, boolean includeTimestampVector) { + if (aggState.declaredType().isPrimitive()) { + return requireArgs(requireType(aggState.declaredType()), requireAnyType("")); + } else if (includeTimestampVector) { + return requireArgs( + requireType(aggState.declaredType()), + requireType(TypeName.INT), + requireType(TypeName.LONG), // @timestamp + requireAnyType("") + ); + } else { + return requireArgs( + requireType(aggState.declaredType()), + requireType(TypeName.INT), + requireAnyType("") + ); + } + } + public ClassName implementation() { return implementation; } @@ -131,18 +158,6 @@ List createParameters() { return createParameters; } - private TypeName choseStateType() { - TypeName initReturn = TypeName.get(init.getReturnType()); - if (false == initReturn.isPrimitive()) { - return initReturn; - } - String simpleName = firstUpper(initReturn.toString()); - if (warnExceptions.isEmpty()) { - return ClassName.get("org.elasticsearch.compute.aggregation", simpleName + "ArrayState"); - } - return ClassName.get("org.elasticsearch.compute.aggregation", simpleName + "FallibleArrayState"); - } - public JavaFile sourceFile() { JavaFile.Builder builder = JavaFile.builder(implementation.packageName(), type()); builder.addFileComment(""" @@ -164,7 +179,7 @@ private TypeSpec type() { .initializer(initInterState()) .build() ); - builder.addField(stateType, "state", Modifier.PRIVATE, Modifier.FINAL); + builder.addField(aggState.type(), "state", Modifier.PRIVATE, Modifier.FINAL); if (warnExceptions.isEmpty() == false) { builder.addField(WARNINGS, "warnings", Modifier.PRIVATE, Modifier.FINAL); } @@ -180,10 +195,10 @@ private TypeSpec type() { builder.addMethod(intermediateStateDesc()); builder.addMethod(intermediateBlockCount()); builder.addMethod(prepareProcessPage()); - builder.addMethod(addRawInputLoop(INT_VECTOR, valueBlockType(init, combine))); - builder.addMethod(addRawInputLoop(INT_VECTOR, valueVectorType(init, combine))); - builder.addMethod(addRawInputLoop(INT_BLOCK, valueBlockType(init, combine))); - builder.addMethod(addRawInputLoop(INT_BLOCK, valueVectorType(init, combine))); + builder.addMethod(addRawInputLoop(INT_VECTOR, blockType(aggParam.type()))); + builder.addMethod(addRawInputLoop(INT_VECTOR, vectorType(aggParam.type()))); + builder.addMethod(addRawInputLoop(INT_BLOCK, blockType(aggParam.type()))); + builder.addMethod(addRawInputLoop(INT_BLOCK, vectorType(aggParam.type()))); builder.addMethod(selectedMayContainUnseenGroups()); builder.addMethod(addIntermediateInput()); builder.addMethod(addIntermediateRowInput()); @@ -230,16 +245,16 @@ private CodeBlock callInit() { .map(p -> TypeName.get(p.asType()).equals(BIG_ARRAYS) ? "driverContext.bigArrays()" : p.getSimpleName().toString()) .collect(joining(", ")); CodeBlock.Builder builder = CodeBlock.builder(); - if (init.getReturnType().toString().equals(stateType.toString())) { - builder.add("$T.$L($L)", declarationType, init.getSimpleName(), initParametersCall); - } else { + if (aggState.declaredType().isPrimitive()) { builder.add( "new $T(driverContext.bigArrays(), $T.$L($L))", - stateType, + aggState.type(), declarationType, init.getSimpleName(), initParametersCall ); + } else { + builder.add("$T.$L($L)", declarationType, init.getSimpleName(), initParametersCall); } return builder.build(); } @@ -263,7 +278,7 @@ private MethodSpec ctor() { builder.addParameter(WARNINGS, "warnings"); } builder.addParameter(LIST_INTEGER, "channels"); - builder.addParameter(stateType, "state"); + builder.addParameter(aggState.type(), "state"); builder.addParameter(DRIVER_CONTEXT, "driverContext"); if (warnExceptions.isEmpty() == false) { builder.addStatement("this.warnings = warnings"); @@ -301,8 +316,8 @@ private MethodSpec prepareProcessPage() { builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC).returns(GROUPING_AGGREGATOR_FUNCTION_ADD_INPUT); builder.addParameter(SEEN_GROUP_IDS, "seenGroupIds").addParameter(PAGE, "page"); - builder.addStatement("$T valuesBlock = page.getBlock(channels.get(0))", valueBlockType(init, combine)); - builder.addStatement("$T valuesVector = valuesBlock.asVector()", valueVectorType(init, combine)); + builder.addStatement("$T valuesBlock = page.getBlock(channels.get(0))", blockType(aggParam.type())); + builder.addStatement("$T valuesVector = valuesBlock.asVector()", vectorType(aggParam.type())); if (includeTimestampVector) { builder.addStatement("$T timestampsBlock = page.getBlock(channels.get(1))", LONG_BLOCK); builder.addStatement("$T timestampsVector = timestampsBlock.asVector()", LONG_VECTOR); @@ -355,18 +370,17 @@ private TypeSpec addInput(Consumer addBlock) { private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) { boolean groupsIsBlock = groupsType.toString().endsWith("Block"); boolean valuesIsBlock = valuesType.toString().endsWith("Block"); - String methodName = "addRawInput"; - MethodSpec.Builder builder = MethodSpec.methodBuilder(methodName); + MethodSpec.Builder builder = MethodSpec.methodBuilder("addRawInput"); builder.addModifiers(Modifier.PRIVATE); builder.addParameter(TypeName.INT, "positionOffset").addParameter(groupsType, "groups").addParameter(valuesType, "values"); if (includeTimestampVector) { builder.addParameter(LONG_VECTOR, "timestamps"); } - if (valuesIsBytesRef) { + if (aggParam.isBytesRef()) { // Add bytes_ref scratch var that will be used for bytes_ref blocks/vectors builder.addStatement("$T scratch = new $T()", BYTES_REF, BYTES_REF); } - if (valuesIsArray && valuesIsBlock == false) { + if (aggParam.isArray() && valuesIsBlock == false) { builder.addComment("This type does not support vectors because all values are multi-valued"); return builder.build(); } @@ -397,11 +411,11 @@ private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) { builder.endControlFlow(); builder.addStatement("int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset)"); builder.addStatement("int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset)"); - if (valuesIsArray) { - String arrayType = valueTypeString(); + if (aggParam.isArray()) { + String arrayType = aggParam.type().toString().replace("[]", ""); builder.addStatement("$L[] valuesArray = new $L[valuesEnd - valuesStart]", arrayType, arrayType); builder.beginControlFlow("for (int v = valuesStart; v < valuesEnd; v++)"); - builder.addStatement("valuesArray[v-valuesStart] = $L.get$L(v)", "values", firstUpper(arrayType)); + builder.addStatement("valuesArray[v-valuesStart] = $L.get$L(v)", "values", capitalize(arrayType)); builder.endControlFlow(); combineRawInputForArray(builder, "valuesArray"); } else { @@ -422,14 +436,12 @@ private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) { } private void combineRawInput(MethodSpec.Builder builder, String blockVariable, String offsetVariable) { - TypeName valueType = valueTypeName(); + TypeName valueType = aggParam.type(); TypeName returnType = TypeName.get(combine.getReturnType()); warningsBlock(builder, () -> { - if (valuesIsBytesRef) { + if (aggParam.isBytesRef()) { combineRawInputForBytesRef(builder, blockVariable, offsetVariable); - } else if (includeTimestampVector) { - combineRawInputWithTimestamp(builder, offsetVariable); } else if (valueType.isPrimitive() == false) { throw new IllegalArgumentException("second parameter to combine must be a primitive, array or BytesRef: " + valueType); } else if (returnType.isPrimitive()) { @@ -442,48 +454,75 @@ private void combineRawInput(MethodSpec.Builder builder, String blockVariable, S }); } - private void combineRawInputForPrimitive(MethodSpec.Builder builder, String blockVariable, String offsetVariable) { - builder.addStatement( - "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.get$L($L)))", - declarationType, - blockVariable, - firstUpper(valueTypeName().toString()), - offsetVariable - ); + private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable, String offsetVariable) { + // scratch is a BytesRef var that must have been defined before the iteration starts + if (includeTimestampVector) { + if (offsetVariable.contains(" + ")) { + builder.addStatement("var valuePosition = $L", offsetVariable); + offsetVariable = "valuePosition"; + } + builder.addStatement( + "$T.combine(state, groupId, timestamps.getLong($L), $L.getBytesRef($L, scratch))", + declarationType, + offsetVariable, + blockVariable, + offsetVariable + ); + } else { + builder.addStatement("$T.combine(state, groupId, $L.getBytesRef($L, scratch))", declarationType, blockVariable, offsetVariable); + } } - private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) { - warningsBlock(builder, () -> builder.addStatement("$T.combine(state, groupId, $L)", declarationType, arrayVariable)); + private void combineRawInputForPrimitive(MethodSpec.Builder builder, String blockVariable, String offsetVariable) { + if (includeTimestampVector) { + if (offsetVariable.contains(" + ")) { + builder.addStatement("var valuePosition = $L", offsetVariable); + offsetVariable = "valuePosition"; + } + builder.addStatement( + "$T.combine(state, groupId, timestamps.getLong($L), values.get$L($L))", + declarationType, + offsetVariable, + capitalize(aggParam.type().toString()), + offsetVariable + ); + } else { + builder.addStatement( + "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.get$L($L)))", + declarationType, + blockVariable, + capitalize(aggParam.type().toString()), + offsetVariable + ); + } } private void combineRawInputForVoid(MethodSpec.Builder builder, String blockVariable, String offsetVariable) { - builder.addStatement( - "$T.combine(state, groupId, $L.get$L($L))", - declarationType, - blockVariable, - firstUpper(valueTypeName().toString()), - offsetVariable - ); - } - - private void combineRawInputWithTimestamp(MethodSpec.Builder builder, String offsetVariable) { - String blockType = firstUpper(valueTypeName().toString()); - if (offsetVariable.contains(" + ")) { - builder.addStatement("var valuePosition = $L", offsetVariable); - offsetVariable = "valuePosition"; + if (includeTimestampVector) { + if (offsetVariable.contains(" + ")) { + builder.addStatement("var valuePosition = $L", offsetVariable); + offsetVariable = "valuePosition"; + } + builder.addStatement( + "$T.combine(state, groupId, timestamps.getLong($L), values.get$L($L))", + declarationType, + offsetVariable, + capitalize(aggParam.type().toString()), + offsetVariable + ); + } else { + builder.addStatement( + "$T.combine(state, groupId, $L.get$L($L))", + declarationType, + blockVariable, + capitalize(aggParam.type().toString()), + offsetVariable + ); } - builder.addStatement( - "$T.combine(state, groupId, timestamps.getLong($L), values.get$L($L))", - declarationType, - offsetVariable, - blockType, - offsetVariable - ); } - private void combineRawInputForBytesRef(MethodSpec.Builder builder, String blockVariable, String offsetVariable) { - // scratch is a BytesRef var that must have been defined before the iteration starts - builder.addStatement("$T.combine(state, groupId, $L.getBytesRef($L, scratch))", declarationType, blockVariable, offsetVariable); + private void combineRawInputForArray(MethodSpec.Builder builder, String arrayVariable) { + warningsBlock(builder, () -> builder.addStatement("$T.combine(state, groupId, $L)", declarationType, arrayVariable)); } private void warningsBlock(MethodSpec.Builder builder, Runnable block) { @@ -539,7 +578,7 @@ private MethodSpec addIntermediateInput() { builder.beginControlFlow("for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++)"); { builder.addStatement("int groupId = groups.getInt(groupPosition)"); - if (hasPrimitiveState()) { + if (aggState.declaredType().isPrimitive()) { if (warnExceptions.isEmpty()) { assert intermediateState.size() == 2; assert intermediateState.get(1).name().equals("seen"); @@ -567,31 +606,33 @@ private MethodSpec addIntermediateInput() { }); builder.endControlFlow(); } else { - builder.addStatement("$T.combineIntermediate(state, groupId, " + intermediateStateRowAccess() + ")", declarationType); + var stateHasBlock = intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block); + requireStaticMethod( + declarationType, + requireVoidType(), + requireName("combineIntermediate"), + requireArgs( + Stream.of( + Stream.of(aggState.declaredType(), TypeName.INT), // aggState and groupId + intermediateState.stream().map(AggregatorImplementer.IntermediateStateDesc::combineArgType), + Stream.of(TypeName.INT).filter(p -> stateHasBlock) // position + ).flatMap(Function.identity()).map(Methods::requireType).toArray(Methods.TypeMatcher[]::new) + ) + ); + + builder.addStatement( + "$T.combineIntermediate(state, groupId, " + + intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", ")) + + (stateHasBlock ? ", groupPosition + positionOffset" : "") + + ")", + declarationType + ); } builder.endControlFlow(); } return builder.build(); } - String intermediateStateRowAccess() { - String rowAccess = intermediateState.stream().map(desc -> desc.access("groupPosition + positionOffset")).collect(joining(", ")); - if (intermediateState.stream().anyMatch(AggregatorImplementer.IntermediateStateDesc::block)) { - rowAccess += ", groupPosition + positionOffset"; - } - return rowAccess; - } - - private void combineStates(MethodSpec.Builder builder) { - if (combineStates == null) { - builder.beginControlFlow("if (inState.hasValue(position))"); - builder.addStatement("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))", declarationType); - builder.endControlFlow(); - return; - } - builder.addStatement("$T.combineStates(state, groupId, inState, position)", declarationType); - } - private MethodSpec addIntermediateRowInput() { MethodSpec.Builder builder = MethodSpec.methodBuilder("addIntermediateRowInput"); builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC); @@ -601,9 +642,26 @@ private MethodSpec addIntermediateRowInput() { builder.addStatement("throw new IllegalArgumentException($S + getClass() + $S + input.getClass())", "expected ", "; got "); } builder.endControlFlow(); - builder.addStatement("$T inState = (($T) input).state", stateType, implementation); + builder.addStatement("$T inState = (($T) input).state", aggState.type(), implementation); builder.addStatement("state.enableGroupIdTracking(new $T.Empty())", SEEN_GROUP_IDS); - combineStates(builder); + if (aggState.declaredType().isPrimitive()) { + builder.beginControlFlow("if (inState.hasValue(position))"); + builder.addStatement("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))", declarationType); + builder.endControlFlow(); + } else { + requireStaticMethod( + declarationType, + requireVoidType(), + requireName("combineStates"), + requireArgs( + requireType(aggState.declaredType()), + requireType(TypeName.INT), + requireType(aggState.declaredType()), + requireType(TypeName.INT) + ) + ); + builder.addStatement("$T.combineStates(state, groupId, inState, position)", declarationType); + } return builder.build(); } @@ -627,9 +685,15 @@ private MethodSpec evaluateFinal() { .addParameter(INT_VECTOR, "selected") .addParameter(DRIVER_CONTEXT, "driverContext"); - if (evaluateFinal == null) { + if (aggState.declaredType().isPrimitive()) { builder.addStatement("blocks[offset] = state.toValuesBlock(selected, driverContext)"); } else { + requireStaticMethod( + declarationType, + requireType(BLOCK), + requireName("evaluateFinal"), + requireArgs(requireType(aggState.declaredType()), requireType(INT_VECTOR), requireType(DRIVER_CONTEXT)) + ); builder.addStatement("blocks[offset] = $T.evaluateFinal(state, selected, driverContext)", declarationType); } return builder.build(); @@ -652,32 +716,4 @@ private MethodSpec close() { builder.addStatement("state.close()"); return builder.build(); } - - private static final Pattern PRIMITIVE_STATE_PATTERN = Pattern.compile( - "org.elasticsearch.compute.aggregation.(Boolean|Int|Long|Double|Float)(Fallible)?ArrayState" - ); - - private boolean hasPrimitiveState() { - return PRIMITIVE_STATE_PATTERN.matcher(stateType.toString()).matches(); - } - - private TypeMirror valueTypeMirror() { - return combine.getParameters().get(combine.getParameters().size() - 1).asType(); - } - - private TypeName valueTypeName() { - return TypeName.get(valueTypeMirror()); - } - - private TypeKind valueTypeKind() { - return valueTypeMirror().getKind(); - } - - private String valueTypeString() { - String valueTypeString = TypeName.get(valueTypeMirror()).toString(); - if (valuesIsArray) { - valueTypeString = valueTypeString.substring(0, valueTypeString.length() - 2); - } - return valueTypeString; - } } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java index 6f98f1f797ab0..f2fa7b8084448 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Methods.java @@ -9,18 +9,22 @@ import com.squareup.javapoet.TypeName; -import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Set; import java.util.function.Predicate; +import java.util.stream.IntStream; +import java.util.stream.Stream; -import javax.lang.model.element.Element; import javax.lang.model.element.ExecutableElement; import javax.lang.model.element.Modifier; import javax.lang.model.element.TypeElement; -import javax.lang.model.element.VariableElement; import javax.lang.model.type.DeclaredType; -import javax.lang.model.type.TypeMirror; +import javax.lang.model.type.TypeKind; import javax.lang.model.util.ElementFilter; +import javax.lang.model.util.Elements; +import static java.util.stream.Collectors.joining; import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK; import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK_BUILDER; import static org.elasticsearch.compute.gen.Types.BOOLEAN_VECTOR; @@ -49,30 +53,116 @@ * Finds declared methods for the code generator. */ public class Methods { - static ExecutableElement findRequiredMethod(TypeElement declarationType, String[] names, Predicate filter) { - ExecutableElement result = findMethod(names, filter, declarationType, superClassOf(declarationType)); - if (result == null) { - if (names.length == 1) { - throw new IllegalArgumentException(declarationType + "#" + names[0] + " is required"); - } - throw new IllegalArgumentException("one of " + declarationType + "#" + Arrays.toString(names) + " is required"); + + static ExecutableElement requireStaticMethod( + TypeElement declarationType, + TypeMatcher returnTypeMatcher, + NameMatcher nameMatcher, + ArgumentMatcher argumentMatcher + ) { + return typeAndSuperType(declarationType).flatMap(type -> ElementFilter.methodsIn(type.getEnclosedElements()).stream()) + .filter(method -> method.getModifiers().contains(Modifier.STATIC)) + .filter(method -> nameMatcher.test(method.getSimpleName().toString())) + .filter(method -> returnTypeMatcher.test(TypeName.get(method.getReturnType()))) + .filter(method -> argumentMatcher.test(method.getParameters().stream().map(it -> TypeName.get(it.asType())).toList())) + .findFirst() + .orElseThrow(() -> { + var message = nameMatcher.names.size() == 1 ? "Requires method: " : "Requires one of methods: "; + var signatures = nameMatcher.names.stream() + .map(name -> "public static " + returnTypeMatcher + " " + declarationType + "#" + name + "(" + argumentMatcher + ")") + .collect(joining(" or ")); + return new IllegalArgumentException(message + signatures); + }); + } + + static NameMatcher requireName(String... names) { + return new NameMatcher(Set.of(names)); + } + + static TypeMatcher requireVoidType() { + return new TypeMatcher(type -> Objects.equals(TypeName.VOID, type), "void"); + } + + static TypeMatcher requireAnyType(String description) { + return new TypeMatcher(type -> true, description); + } + + static TypeMatcher requirePrimitiveOrImplements(Elements elements, TypeName requiredInterface) { + return new TypeMatcher( + type -> type.isPrimitive() || isImplementing(elements, type, requiredInterface), + "[boolean|int|long|float|double|" + requiredInterface + "]" + ); + } + + static TypeMatcher requireType(TypeName requiredType) { + return new TypeMatcher(type -> Objects.equals(requiredType, type), requiredType.toString()); + } + + static ArgumentMatcher requireAnyArgs(String description) { + return new ArgumentMatcher(args -> true, description); + } + + static ArgumentMatcher requireArgs(TypeMatcher... argTypes) { + return new ArgumentMatcher( + args -> args.size() == argTypes.length && IntStream.range(0, argTypes.length).allMatch(i -> argTypes[i].test(args.get(i))), + Stream.of(argTypes).map(TypeMatcher::toString).collect(joining(", ")) + ); + } + + record NameMatcher(Set names) implements Predicate { + @Override + public boolean test(String name) { + return names.contains(name); } - return result; } - static ExecutableElement findMethod(TypeElement declarationType, String name) { - return findMethod(new String[] { name }, e -> true, declarationType, superClassOf(declarationType)); + record TypeMatcher(Predicate matcher, String description) implements Predicate { + @Override + public boolean test(TypeName typeName) { + return matcher.test(typeName); + } + + @Override + public String toString() { + return description; + } } - private static TypeElement superClassOf(TypeElement declarationType) { - TypeMirror superclass = declarationType.getSuperclass(); - if (superclass instanceof DeclaredType declaredType) { - Element superclassElement = declaredType.asElement(); - if (superclassElement instanceof TypeElement) { - return (TypeElement) superclassElement; - } + record ArgumentMatcher(Predicate> matcher, String description) implements Predicate> { + @Override + public boolean test(List typeName) { + return matcher.test(typeName); + } + + @Override + public String toString() { + return description; + } + } + + private static boolean isImplementing(Elements elements, TypeName type, TypeName requiredInterface) { + return allInterfacesOf(elements, type).anyMatch( + anInterface -> Objects.equals(anInterface.toString(), requiredInterface.toString()) + ); + } + + private static Stream allInterfacesOf(Elements elements, TypeName type) { + var typeElement = elements.getTypeElement(type.toString()); + var superType = Stream.of(typeElement.getSuperclass()).filter(sType -> sType.getKind() != TypeKind.NONE).map(TypeName::get); + var interfaces = typeElement.getInterfaces().stream().map(TypeName::get); + return Stream.concat( + superType.flatMap(sType -> allInterfacesOf(elements, sType)), + interfaces.flatMap(anInterface -> Stream.concat(Stream.of(anInterface), allInterfacesOf(elements, anInterface))) + ); + } + + private static Stream typeAndSuperType(TypeElement declarationType) { + if (declarationType.getSuperclass() instanceof DeclaredType declaredType + && declaredType.asElement() instanceof TypeElement superType) { + return Stream.of(declarationType, superType); + } else { + return Stream.of(declarationType); } - return null; } static ExecutableElement findMethod(TypeElement declarationType, String[] names, Predicate filter) { @@ -95,16 +185,6 @@ static ExecutableElement findMethod(String[] names, Predicate return null; } - /** - * Returns the arguments of a method after applying a filter. - */ - static VariableElement[] findMethodArguments(ExecutableElement method, Predicate filter) { - if (method.getParameters().isEmpty()) { - return new VariableElement[0]; - } - return method.getParameters().stream().filter(filter).toArray(VariableElement[]::new); - } - /** * Returns the name of the method used to add {@code valueType} instances * to vector or block builders. diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java index 8b01d957f3bd2..35c42153f9ad6 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/Types.java @@ -15,9 +15,13 @@ import java.util.ArrayDeque; import java.util.Deque; import java.util.List; +import java.util.Map; +import java.util.stream.Stream; import javax.lang.model.type.TypeMirror; +import static java.util.stream.Collectors.toUnmodifiableMap; + /** * Types used by the code generator. */ @@ -75,26 +79,8 @@ public class Types { static final ClassName DOUBLE_VECTOR_FIXED_BUILDER = ClassName.get(DATA_PACKAGE, "DoubleVector", "FixedBuilder"); static final ClassName FLOAT_VECTOR_FIXED_BUILDER = ClassName.get(DATA_PACKAGE, "FloatVector", "FixedBuilder"); - static final ClassName BOOLEAN_ARRAY_VECTOR = ClassName.get(DATA_PACKAGE, "BooleanArrayVector"); - static final ClassName BYTES_REF_ARRAY_VECTOR = ClassName.get(DATA_PACKAGE, "BytesRefArrayVector"); - static final ClassName INT_ARRAY_VECTOR = ClassName.get(DATA_PACKAGE, "IntArrayVector"); - static final ClassName LONG_ARRAY_VECTOR = ClassName.get(DATA_PACKAGE, "LongArrayVector"); - static final ClassName DOUBLE_ARRAY_VECTOR = ClassName.get(DATA_PACKAGE, "DoubleArrayVector"); - static final ClassName FLOAT_ARRAY_VECTOR = ClassName.get(DATA_PACKAGE, "FloatArrayVector"); - - static final ClassName BOOLEAN_ARRAY_BLOCK = ClassName.get(DATA_PACKAGE, "BooleanArrayBlock"); - static final ClassName BYTES_REF_ARRAY_BLOCK = ClassName.get(DATA_PACKAGE, "BytesRefArrayBlock"); - static final ClassName INT_ARRAY_BLOCK = ClassName.get(DATA_PACKAGE, "IntArrayBlock"); - static final ClassName LONG_ARRAY_BLOCK = ClassName.get(DATA_PACKAGE, "LongArrayBlock"); - static final ClassName DOUBLE_ARRAY_BLOCK = ClassName.get(DATA_PACKAGE, "DoubleArrayBlock"); - static final ClassName FLOAT_ARRAY_BLOCK = ClassName.get(DATA_PACKAGE, "FloatArrayBlock"); - - static final ClassName BOOLEAN_CONSTANT_VECTOR = ClassName.get(DATA_PACKAGE, "ConstantBooleanVector"); - static final ClassName BYTES_REF_CONSTANT_VECTOR = ClassName.get(DATA_PACKAGE, "ConstantBytesRefVector"); - static final ClassName INT_CONSTANT_VECTOR = ClassName.get(DATA_PACKAGE, "ConstantIntVector"); - static final ClassName LONG_CONSTANT_VECTOR = ClassName.get(DATA_PACKAGE, "ConstantLongVector"); - static final ClassName DOUBLE_CONSTANT_VECTOR = ClassName.get(DATA_PACKAGE, "ConstantDoubleVector"); - static final ClassName FLOAT_CONSTANT_VECTOR = ClassName.get(DATA_PACKAGE, "ConstantFloatVector"); + static final ClassName AGGREGATOR_STATE = ClassName.get(AGGREGATION_PACKAGE, "AggregatorState"); + static final ClassName GROUPING_AGGREGATOR_STATE = ClassName.get(AGGREGATION_PACKAGE, "GroupingAggregatorState"); static final ClassName AGGREGATOR_FUNCTION = ClassName.get(AGGREGATION_PACKAGE, "AggregatorFunction"); static final ClassName AGGREGATOR_FUNCTION_SUPPLIER = ClassName.get(AGGREGATION_PACKAGE, "AggregatorFunctionSupplier"); @@ -138,89 +124,50 @@ public class Types { static final ClassName RELEASABLE = ClassName.get("org.elasticsearch.core", "Releasable"); static final ClassName RELEASABLES = ClassName.get("org.elasticsearch.core", "Releasables"); - static ClassName blockType(TypeName elementType) { - if (elementType.equals(TypeName.BOOLEAN)) { - return BOOLEAN_BLOCK; - } - if (elementType.equals(BYTES_REF)) { - return BYTES_REF_BLOCK; - } - if (elementType.equals(TypeName.INT)) { - return INT_BLOCK; - } - if (elementType.equals(TypeName.LONG)) { - return LONG_BLOCK; - } - if (elementType.equals(TypeName.DOUBLE)) { - return DOUBLE_BLOCK; + private record TypeDef(TypeName type, String alias, ClassName block, ClassName vector) { + + public static TypeDef of(TypeName type, String alias, String block, String vector) { + return new TypeDef(type, alias, ClassName.get(DATA_PACKAGE, block), ClassName.get(DATA_PACKAGE, vector)); } - throw new IllegalArgumentException("unknown block type for [" + elementType + "]"); + } + + private static final Map TYPES = Stream.of( + TypeDef.of(TypeName.BOOLEAN, "BOOLEAN", "BooleanBlock", "BooleanVector"), + TypeDef.of(TypeName.INT, "INT", "IntBlock", "IntVector"), + TypeDef.of(TypeName.LONG, "LONG", "LongBlock", "LongVector"), + TypeDef.of(TypeName.FLOAT, "FLOAT", "FloatBlock", "FloatVector"), + TypeDef.of(TypeName.DOUBLE, "DOUBLE", "DoubleBlock", "DoubleVector"), + TypeDef.of(BYTES_REF, "BYTES_REF", "BytesRefBlock", "BytesRefVector") + ) + .flatMap(def -> Stream.of(def.type.toString(), def.type + "[]", def.alias).map(alias -> Map.entry(alias, def))) + .collect(toUnmodifiableMap(Map.Entry::getKey, Map.Entry::getValue)); + + private static TypeDef findRequired(String name, String kind) { + TypeDef typeDef = TYPES.get(name); + if (typeDef == null) { + throw new IllegalArgumentException("unknown " + kind + " type [" + name + "]"); + } + return typeDef; + } + + static TypeName fromString(String type) { + return findRequired(type, "plain").type; + } + + static ClassName blockType(TypeName elementType) { + return blockType(elementType.toString()); } static ClassName blockType(String elementType) { - if (elementType.equalsIgnoreCase(TypeName.BOOLEAN.toString())) { - return BOOLEAN_BLOCK; - } - if (elementType.equalsIgnoreCase("BYTES_REF")) { - return BYTES_REF_BLOCK; - } - if (elementType.equalsIgnoreCase(TypeName.INT.toString())) { - return INT_BLOCK; - } - if (elementType.equalsIgnoreCase(TypeName.LONG.toString())) { - return LONG_BLOCK; - } - if (elementType.equalsIgnoreCase(TypeName.DOUBLE.toString())) { - return DOUBLE_BLOCK; - } - if (elementType.equalsIgnoreCase(TypeName.FLOAT.toString())) { - return FLOAT_BLOCK; - } - throw new IllegalArgumentException("unknown vector type for [" + elementType + "]"); + return findRequired(elementType, "block").block; } static ClassName vectorType(TypeName elementType) { - if (elementType.equals(TypeName.BOOLEAN)) { - return BOOLEAN_VECTOR; - } - if (elementType.equals(BYTES_REF)) { - return BYTES_REF_VECTOR; - } - if (elementType.equals(TypeName.INT)) { - return INT_VECTOR; - } - if (elementType.equals(TypeName.LONG)) { - return LONG_VECTOR; - } - if (elementType.equals(TypeName.DOUBLE)) { - return DOUBLE_VECTOR; - } - if (elementType.equals(TypeName.FLOAT)) { - return FLOAT_VECTOR; - } - throw new IllegalArgumentException("unknown vector type for [" + elementType + "]"); + return vectorType(elementType.toString()); } static ClassName vectorType(String elementType) { - if (elementType.equalsIgnoreCase(TypeName.BOOLEAN.toString())) { - return BOOLEAN_VECTOR; - } - if (elementType.equalsIgnoreCase("BYTES_REF")) { - return BYTES_REF_VECTOR; - } - if (elementType.equalsIgnoreCase(TypeName.INT.toString())) { - return INT_VECTOR; - } - if (elementType.equalsIgnoreCase(TypeName.LONG.toString())) { - return LONG_VECTOR; - } - if (elementType.equalsIgnoreCase(TypeName.DOUBLE.toString())) { - return DOUBLE_VECTOR; - } - if (elementType.equalsIgnoreCase(TypeName.FLOAT.toString())) { - return FLOAT_VECTOR; - } - throw new IllegalArgumentException("unknown vector type for [" + elementType + "]"); + return findRequired(elementType, "vector").vector; } static ClassName builderType(TypeName resultType) { @@ -282,63 +229,6 @@ static ClassName vectorFixedBuilderType(TypeName elementType) { throw new IllegalArgumentException("unknown vector fixed builder type for [" + elementType + "]"); } - static ClassName arrayVectorType(TypeName elementType) { - if (elementType.equals(TypeName.BOOLEAN)) { - return BOOLEAN_ARRAY_VECTOR; - } - if (elementType.equals(BYTES_REF)) { - return BYTES_REF_ARRAY_VECTOR; - } - if (elementType.equals(TypeName.INT)) { - return INT_ARRAY_VECTOR; - } - if (elementType.equals(TypeName.LONG)) { - return LONG_ARRAY_VECTOR; - } - if (elementType.equals(TypeName.DOUBLE)) { - return DOUBLE_ARRAY_VECTOR; - } - throw new IllegalArgumentException("unknown vector type for [" + elementType + "]"); - } - - static ClassName arrayBlockType(TypeName elementType) { - if (elementType.equals(TypeName.BOOLEAN)) { - return BOOLEAN_ARRAY_BLOCK; - } - if (elementType.equals(BYTES_REF)) { - return BYTES_REF_ARRAY_BLOCK; - } - if (elementType.equals(TypeName.INT)) { - return INT_ARRAY_BLOCK; - } - if (elementType.equals(TypeName.LONG)) { - return LONG_ARRAY_BLOCK; - } - if (elementType.equals(TypeName.DOUBLE)) { - return DOUBLE_ARRAY_BLOCK; - } - throw new IllegalArgumentException("unknown vector type for [" + elementType + "]"); - } - - static ClassName constantVectorType(TypeName elementType) { - if (elementType.equals(TypeName.BOOLEAN)) { - return BOOLEAN_CONSTANT_VECTOR; - } - if (elementType.equals(BYTES_REF)) { - return BYTES_REF_CONSTANT_VECTOR; - } - if (elementType.equals(TypeName.INT)) { - return INT_CONSTANT_VECTOR; - } - if (elementType.equals(TypeName.LONG)) { - return LONG_CONSTANT_VECTOR; - } - if (elementType.equals(TypeName.DOUBLE)) { - return DOUBLE_CONSTANT_VECTOR; - } - throw new IllegalArgumentException("unknown vector type for [" + elementType + "]"); - } - static TypeName elementType(TypeName t) { if (t.equals(BOOLEAN_BLOCK) || t.equals(BOOLEAN_VECTOR) || t.equals(BOOLEAN_BLOCK_BUILDER)) { return TypeName.BOOLEAN; diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleAggregator.java index cbd20f15c6511..deec1ef04f623 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleAggregator.java @@ -333,7 +333,8 @@ Block evaluateFinal(IntVector selected, BlockFactory blockFactory) { } } - void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + @Override + public void enableGroupIdTracking(SeenGroupIds seenGroupIds) { // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateFloatAggregator.java index b50b125d98331..94ad5254bc723 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateFloatAggregator.java @@ -334,7 +334,8 @@ Block evaluateFinal(IntVector selected, BlockFactory blockFactory) { } } - void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + @Override + public void enableGroupIdTracking(SeenGroupIds seenGroupIds) { // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntAggregator.java index 01c3e3d7fb8e7..011291dd08c52 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntAggregator.java @@ -334,7 +334,8 @@ Block evaluateFinal(IntVector selected, BlockFactory blockFactory) { } } - void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + @Override + public void enableGroupIdTracking(SeenGroupIds seenGroupIds) { // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongAggregator.java index c84985b703aed..9ccb5d3bd1b1a 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongAggregator.java @@ -333,7 +333,8 @@ Block evaluateFinal(IntVector selected, BlockFactory blockFactory) { } } - void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + @Override + public void enableGroupIdTracking(SeenGroupIds seenGroupIds) { // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` } } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java index 32391c4827303..a2e86b3b09340 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java @@ -17,7 +17,6 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.sort.BooleanBucketedSort; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.search.sort.SortOrder; @@ -74,7 +73,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(driverContext.blockFactory(), selected); } - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final BooleanBucketedSort sort; private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { @@ -89,7 +88,8 @@ public void merge(int groupId, GroupingState other, int otherGroupId) { sort.merge(groupId, other.sort, otherGroupId); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } @@ -97,7 +97,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return sort.toBlock(blockFactory, selected); } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block } @@ -107,7 +108,7 @@ public void close() { } } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final GroupingState internalState; private SingleState(BigArrays bigArrays, int limit, boolean ascending) { @@ -122,7 +123,8 @@ public void merge(GroupingState other) { internalState.merge(0, other, 0); } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java index c9b0e679b3e64..0a965899c0775 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopBytesRefAggregator.java @@ -19,7 +19,6 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.sort.BytesRefBucketedSort; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.search.sort.SortOrder; @@ -78,7 +77,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(driverContext.blockFactory(), selected); } - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final BytesRefBucketedSort sort; private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { @@ -95,7 +94,8 @@ public void merge(int groupId, GroupingState other, int otherGroupId) { sort.merge(groupId, other.sort, otherGroupId); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } @@ -103,7 +103,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return sort.toBlock(blockFactory, selected); } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block } @@ -113,7 +114,7 @@ public void close() { } } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final GroupingState internalState; private SingleState(BigArrays bigArrays, int limit, boolean ascending) { @@ -128,7 +129,8 @@ public void merge(GroupingState other) { internalState.merge(0, other, 0); } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleAggregator.java index d9a7a302f07c8..6a20ed99bc236 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopDoubleAggregator.java @@ -17,7 +17,6 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.sort.DoubleBucketedSort; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.search.sort.SortOrder; @@ -74,7 +73,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(driverContext.blockFactory(), selected); } - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final DoubleBucketedSort sort; private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { @@ -89,7 +88,8 @@ public void merge(int groupId, GroupingState other, int otherGroupId) { sort.merge(groupId, other.sort, otherGroupId); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } @@ -97,7 +97,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return sort.toBlock(blockFactory, selected); } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block } @@ -107,7 +108,7 @@ public void close() { } } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final GroupingState internalState; private SingleState(BigArrays bigArrays, int limit, boolean ascending) { @@ -122,7 +123,8 @@ public void merge(GroupingState other) { internalState.merge(0, other, 0); } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatAggregator.java index 8b65261e10f46..cf6ad0f9017de 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopFloatAggregator.java @@ -17,7 +17,6 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.sort.FloatBucketedSort; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.search.sort.SortOrder; @@ -74,7 +73,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(driverContext.blockFactory(), selected); } - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final FloatBucketedSort sort; private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { @@ -89,7 +88,8 @@ public void merge(int groupId, GroupingState other, int otherGroupId) { sort.merge(groupId, other.sort, otherGroupId); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } @@ -97,7 +97,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return sort.toBlock(blockFactory, selected); } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block } @@ -107,7 +108,7 @@ public void close() { } } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final GroupingState internalState; private SingleState(BigArrays bigArrays, int limit, boolean ascending) { @@ -122,7 +123,8 @@ public void merge(GroupingState other) { internalState.merge(0, other, 0); } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntAggregator.java index 5c6b79f710af5..f4ac83c438063 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIntAggregator.java @@ -17,7 +17,6 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.sort.IntBucketedSort; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.search.sort.SortOrder; @@ -74,7 +73,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(driverContext.blockFactory(), selected); } - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final IntBucketedSort sort; private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { @@ -89,7 +88,8 @@ public void merge(int groupId, GroupingState other, int otherGroupId) { sort.merge(groupId, other.sort, otherGroupId); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } @@ -97,7 +97,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return sort.toBlock(blockFactory, selected); } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block } @@ -107,7 +108,7 @@ public void close() { } } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final GroupingState internalState; private SingleState(BigArrays bigArrays, int limit, boolean ascending) { @@ -122,7 +123,8 @@ public void merge(GroupingState other) { internalState.merge(0, other, 0); } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIpAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIpAggregator.java index 219f7385b56df..292dd539edeb5 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIpAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopIpAggregator.java @@ -18,7 +18,6 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.sort.IpBucketedSort; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.search.sort.SortOrder; @@ -77,7 +76,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(driverContext.blockFactory(), selected); } - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final IpBucketedSort sort; private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { @@ -92,7 +91,8 @@ public void merge(int groupId, GroupingState other, int otherGroupId) { sort.merge(groupId, other.sort, otherGroupId); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } @@ -100,7 +100,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return sort.toBlock(blockFactory, selected); } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block } @@ -110,7 +111,7 @@ public void close() { } } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final GroupingState internalState; private SingleState(BigArrays bigArrays, int limit, boolean ascending) { @@ -125,7 +126,8 @@ public void merge(GroupingState other) { internalState.merge(0, other, 0); } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongAggregator.java index 44cef8df7257b..c5af92956bec1 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/TopLongAggregator.java @@ -17,7 +17,6 @@ import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.data.sort.LongBucketedSort; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.search.sort.SortOrder; @@ -74,7 +73,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(driverContext.blockFactory(), selected); } - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final LongBucketedSort sort; private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { @@ -89,7 +88,8 @@ public void merge(int groupId, GroupingState other, int otherGroupId) { sort.merge(groupId, other.sort, otherGroupId); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } @@ -97,7 +97,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { return sort.toBlock(blockFactory, selected); } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block } @@ -107,7 +108,7 @@ public void close() { } } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final GroupingState internalState; private SingleState(BigArrays bigArrays, int limit, boolean ascending) { @@ -122,7 +123,8 @@ public void merge(GroupingState other) { internalState.merge(0, other, 0); } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java index bd77bd7ff1e46..ad0ab2f7189f6 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesBytesRefAggregator.java @@ -20,7 +20,6 @@ import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; /** @@ -83,14 +82,15 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(driverContext.blockFactory(), selected); } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final BytesRefHash values; private SingleState(BigArrays bigArrays) { values = new BytesRefHash(1, bigArrays); } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); } @@ -125,7 +125,7 @@ public void close() { * an {@code O(n^2)} operation for collection to support a {@code O(1)} * collector operation. But at least it's fairly simple. */ - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final LongLongHash values; private final BytesRefHash bytes; @@ -146,7 +146,8 @@ private GroupingState(BigArrays bigArrays) { } } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } @@ -190,7 +191,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { } } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java index a8409367bc090..271d7120092ca 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesDoubleAggregator.java @@ -18,7 +18,6 @@ import org.elasticsearch.compute.data.DoubleBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; /** * Aggregates field values for double. @@ -77,14 +76,15 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(driverContext.blockFactory(), selected); } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final LongHash values; private SingleState(BigArrays bigArrays) { values = new LongHash(1, bigArrays); } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); } @@ -118,14 +118,15 @@ public void close() { * an {@code O(n^2)} operation for collection to support a {@code O(1)} * collector operation. But at least it's fairly simple. */ - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final LongLongHash values; private GroupingState(BigArrays bigArrays) { values = new LongLongHash(1, bigArrays); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } @@ -168,7 +169,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { } } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java index f9e5e1b7b283a..b44cad807fba2 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesFloatAggregator.java @@ -17,7 +17,6 @@ import org.elasticsearch.compute.data.FloatBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; /** * Aggregates field values for float. @@ -82,14 +81,15 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(driverContext.blockFactory(), selected); } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final LongHash values; private SingleState(BigArrays bigArrays) { values = new LongHash(1, bigArrays); } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); } @@ -123,14 +123,15 @@ public void close() { * an {@code O(n^2)} operation for collection to support a {@code O(1)} * collector operation. But at least it's fairly simple. */ - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final LongHash values; private GroupingState(BigArrays bigArrays) { values = new LongHash(1, bigArrays); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } @@ -175,7 +176,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { } } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java index 2420dcee70712..4d0c518245694 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesIntAggregator.java @@ -17,7 +17,6 @@ import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; /** * Aggregates field values for int. @@ -82,14 +81,15 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(driverContext.blockFactory(), selected); } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final LongHash values; private SingleState(BigArrays bigArrays) { values = new LongHash(1, bigArrays); } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); } @@ -123,14 +123,15 @@ public void close() { * an {@code O(n^2)} operation for collection to support a {@code O(1)} * collector operation. But at least it's fairly simple. */ - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final LongHash values; private GroupingState(BigArrays bigArrays) { values = new LongHash(1, bigArrays); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } @@ -175,7 +176,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { } } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java index 4938b8f15edb0..5471c90147ec4 100644 --- a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/ValuesLongAggregator.java @@ -18,7 +18,6 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LongBlock; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; /** * Aggregates field values for long. @@ -77,14 +76,15 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(driverContext.blockFactory(), selected); } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final LongHash values; private SingleState(BigArrays bigArrays) { values = new LongHash(1, bigArrays); } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); } @@ -118,14 +118,15 @@ public void close() { * an {@code O(n^2)} operation for collection to support a {@code O(1)} * collector operation. But at least it's fairly simple. */ - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final LongLongHash values; private GroupingState(BigArrays bigArrays) { values = new LongLongHash(1, bigArrays); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } @@ -168,7 +169,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { } } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractArrayState.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractArrayState.java index 5fa1394e8cf96..9886e0c1af306 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractArrayState.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/AbstractArrayState.java @@ -37,6 +37,7 @@ public boolean hasValue(int groupId) { * idempotent and fast if already tracking so it's safe to, say, call it once * for every block of values that arrives containing {@code null}. */ + @Override public final void enableGroupIdTracking(SeenGroupIds seenGroupIds) { if (seen == null) { seen = seenGroupIds.seenGroupIds(bigArrays); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/BytesRefArrayState.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/BytesRefArrayState.java index eb0a992c8610f..18b92c5447076 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/BytesRefArrayState.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/BytesRefArrayState.java @@ -138,7 +138,8 @@ boolean hasValue(int groupId) { * stores a flag to know if optimizations can be made. *

*/ - void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + @Override + public void enableGroupIdTracking(SeenGroupIds seenGroupIds) { this.groupIdTrackingEnabled = true; } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorState.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorState.java index 7c644342598dc..0e65164665808 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorState.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/GroupingAggregatorState.java @@ -17,4 +17,5 @@ public interface GroupingAggregatorState extends Releasable { /** Extracts an intermediate view of the contents of this state. */ void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext); + void enableGroupIdTracking(SeenGroupIds seenGroupIds); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/HllStates.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/HllStates.java index 3d8d04d7dc7e3..64a970c2acc07 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/HllStates.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/HllStates.java @@ -138,7 +138,8 @@ static class GroupingState implements GroupingAggregatorState { this.hll = new HyperLogLogPlusPlus(HyperLogLogPlusPlus.precisionFromThreshold(precision), bigArrays, 1); } - void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + @Override + public void enableGroupIdTracking(SeenGroupIds seenGroupIds) { // Nothing to do } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxBytesRefAggregator.java index 144214f93571e..049642c350917 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxBytesRefAggregator.java @@ -17,7 +17,6 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; /** @@ -71,7 +70,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(selected, driverContext); } - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final BytesRefArrayState internalState; private GroupingState(BigArrays bigArrays, CircuitBreaker breaker) { @@ -90,7 +89,8 @@ public void combine(int groupId, GroupingState otherState, int otherGroupId) { } } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { internalState.toIntermediate(blocks, offset, selected, driverContext); } @@ -98,7 +98,8 @@ Block toBlock(IntVector selected, DriverContext driverContext) { return internalState.toValuesBlock(selected, driverContext); } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { internalState.enableGroupIdTracking(seen); } @@ -108,7 +109,7 @@ public void close() { } } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final BreakingBytesRefBuilder internalState; private boolean seen; @@ -128,7 +129,8 @@ public void add(BytesRef value) { } } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = driverContext.blockFactory().newConstantBytesRefBlockWith(internalState.bytesRefView(), 1); blocks[offset + 1] = driverContext.blockFactory().newConstantBooleanBlockWith(seen, 1); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxIpAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxIpAggregator.java index 1ddce7674ae7b..43b4a4a2fe0a1 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxIpAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxIpAggregator.java @@ -15,7 +15,6 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; @Aggregator({ @IntermediateState(name = "max", type = "BYTES_REF"), @IntermediateState(name = "seen", type = "BOOLEAN") }) @@ -67,7 +66,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(selected, driverContext); } - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final BytesRef scratch = new BytesRef(); private final IpArrayState internalState; @@ -87,7 +86,8 @@ public void combine(int groupId, GroupingState otherState, int otherGroupId) { } } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { internalState.toIntermediate(blocks, offset, selected, driverContext); } @@ -95,7 +95,8 @@ Block toBlock(IntVector selected, DriverContext driverContext) { return internalState.toValuesBlock(selected, driverContext); } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { internalState.enableGroupIdTracking(seen); } @@ -105,7 +106,7 @@ public void close() { } } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final BytesRef internalState; private boolean seen; @@ -121,7 +122,8 @@ public void add(BytesRef value) { } } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = driverContext.blockFactory().newConstantBytesRefBlockWith(internalState, 1); blocks[offset + 1] = driverContext.blockFactory().newConstantBooleanBlockWith(seen, 1); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinBytesRefAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinBytesRefAggregator.java index 830900702a371..677b38a9af3a7 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinBytesRefAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinBytesRefAggregator.java @@ -17,7 +17,6 @@ import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.BreakingBytesRefBuilder; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; /** @@ -71,7 +70,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(selected, driverContext); } - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final BytesRefArrayState internalState; private GroupingState(BigArrays bigArrays, CircuitBreaker breaker) { @@ -90,7 +89,8 @@ public void combine(int groupId, GroupingState otherState, int otherGroupId) { } } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { internalState.toIntermediate(blocks, offset, selected, driverContext); } @@ -98,7 +98,8 @@ Block toBlock(IntVector selected, DriverContext driverContext) { return internalState.toValuesBlock(selected, driverContext); } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { internalState.enableGroupIdTracking(seen); } @@ -108,7 +109,7 @@ public void close() { } } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final BreakingBytesRefBuilder internalState; private boolean seen; @@ -128,7 +129,8 @@ public void add(BytesRef value) { } } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = driverContext.blockFactory().newConstantBytesRefBlockWith(internalState.bytesRefView(), 1); blocks[offset + 1] = driverContext.blockFactory().newConstantBooleanBlockWith(seen, 1); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinIpAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinIpAggregator.java index 8313756851c1f..c4ee93db89cf8 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinIpAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinIpAggregator.java @@ -15,7 +15,6 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; @Aggregator({ @IntermediateState(name = "max", type = "BYTES_REF"), @IntermediateState(name = "seen", type = "BOOLEAN") }) @@ -67,7 +66,7 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(selected, driverContext); } - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final BytesRef scratch = new BytesRef(); private final IpArrayState internalState; @@ -87,7 +86,8 @@ public void combine(int groupId, GroupingState otherState, int otherGroupId) { } } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { internalState.toIntermediate(blocks, offset, selected, driverContext); } @@ -95,7 +95,8 @@ Block toBlock(IntVector selected, DriverContext driverContext) { return internalState.toValuesBlock(selected, driverContext); } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { internalState.enableGroupIdTracking(seen); } @@ -105,7 +106,7 @@ public void close() { } } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final BytesRef internalState; private boolean seen; @@ -121,7 +122,8 @@ public void add(BytesRef value) { } } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = driverContext.blockFactory().newConstantBytesRefBlockWith(internalState, 1); blocks[offset + 1] = driverContext.blockFactory().newConstantBooleanBlockWith(seen, 1); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/QuantileStates.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/QuantileStates.java index 329e798dcb3f0..d5ea72ed23e5e 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/QuantileStates.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/QuantileStates.java @@ -146,7 +146,8 @@ void add(int groupId, TDigestState other) { } } - void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + @Override + public void enableGroupIdTracking(SeenGroupIds seenGroupIds) { // We always enable. } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java index bff8903fd3bec..5b48498d83294 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/StdDevStates.java @@ -204,7 +204,8 @@ public void close() { Releasables.close(states); } - void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + @Override + public void enableGroupIdTracking(SeenGroupIds seenGroupIds) { // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBooleanAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBooleanAggregator.java index 252436ad9634f..e19d3107172e3 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBooleanAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/ValuesBooleanAggregator.java @@ -17,7 +17,6 @@ import org.elasticsearch.compute.data.BooleanBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; /** @@ -84,11 +83,12 @@ public static Block evaluateFinal(GroupingState state, IntVector selected, Drive return state.toBlock(driverContext.blockFactory(), selected); } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private boolean seenFalse; private boolean seenTrue; - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); } @@ -113,14 +113,15 @@ Block toBlock(BlockFactory blockFactory) { public void close() {} } - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final BitArray values; private GroupingState(BigArrays bigArrays) { values = new BitArray(1, bigArrays); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } @@ -155,7 +156,8 @@ Block toBlock(BlockFactory blockFactory, IntVector selected) { } } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { // we don't need to track which values have been seen because we don't do anything special for groups without values } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateAggregator.java.st index 2581d3ebbf80b..a0b4ed8bd6337 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateAggregator.java.st @@ -338,7 +338,8 @@ public class Rate$Type$Aggregator { } } - void enableGroupIdTracking(SeenGroupIds seenGroupIds) { + @Override + public void enableGroupIdTracking(SeenGroupIds seenGroupIds) { // noop - we handle the null states inside `toIntermediate` and `evaluateFinal` } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st index 18d573eea4a4c..761b70791e946 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-TopAggregator.java.st @@ -28,7 +28,6 @@ import org.elasticsearch.compute.data.$Type$Block; $endif$ import org.elasticsearch.compute.data.sort.$Name$BucketedSort; import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.search.sort.SortOrder; @@ -99,7 +98,7 @@ $endif$ return state.toBlock(driverContext.blockFactory(), selected); } - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { private final $Name$BucketedSort sort; private GroupingState(BigArrays bigArrays, int limit, boolean ascending) { @@ -120,7 +119,8 @@ $endif$ sort.merge(groupId, other.sort, otherGroupId); } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } @@ -128,7 +128,8 @@ $endif$ return sort.toBlock(blockFactory, selected); } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block } @@ -138,7 +139,7 @@ $endif$ } } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { private final GroupingState internalState; private SingleState(BigArrays bigArrays, int limit, boolean ascending) { @@ -153,7 +154,8 @@ $endif$ internalState.merge(0, other, 0); } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st index 1cef234b2238f..3006af595be1f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ValuesAggregator.java.st @@ -35,7 +35,6 @@ $if(long)$ import org.elasticsearch.compute.data.LongBlock; $endif$ import org.elasticsearch.compute.operator.DriverContext; -import org.elasticsearch.core.Releasable; $if(BytesRef)$ import org.elasticsearch.core.Releasables; @@ -155,7 +154,7 @@ $endif$ return state.toBlock(driverContext.blockFactory(), selected); } - public static class SingleState implements Releasable { + public static class SingleState implements AggregatorState { $if(BytesRef)$ private final BytesRefHash values; @@ -171,7 +170,8 @@ $else$ $endif$ } - void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory()); } @@ -228,7 +228,7 @@ $endif$ * an {@code O(n^2)} operation for collection to support a {@code O(1)} * collector operation. But at least it's fairly simple. */ - public static class GroupingState implements Releasable { + public static class GroupingState implements GroupingAggregatorState { $if(long||double)$ private final LongLongHash values; @@ -263,7 +263,8 @@ $elseif(int||float)$ $endif$ } - void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { + @Override + public void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) { blocks[offset] = toBlock(driverContext.blockFactory(), selected); } @@ -324,7 +325,8 @@ $endif$ } } - void enableGroupIdTracking(SeenGroupIds seen) { + @Override + public void enableGroupIdTracking(SeenGroupIds seen) { // we figure out seen values from nulls on the values block } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/CentroidPointAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/CentroidPointAggregator.java index 47d927fda91b5..c3b07d069cf11 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/CentroidPointAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/spatial/CentroidPointAggregator.java @@ -260,7 +260,8 @@ boolean hasValue(int index) { } /** Needed for generated code that does null tracking, which we do not need because we use count */ - final void enableGroupIdTracking(SeenGroupIds ignore) {} + @Override + public final void enableGroupIdTracking(SeenGroupIds ignore) {} private void ensureCapacity(int groupId) { if (groupId >= xValues.size()) {