Skip to content

Commit

Permalink
[9.0] backport various aggs code gen improvements (#122360)
Browse files Browse the repository at this point in the history
  • Loading branch information
idegtiarenko authored Feb 14, 2025
1 parent 3202262 commit c2e632a
Show file tree
Hide file tree
Showing 37 changed files with 754 additions and 692 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,6 @@
* are ever collected.
* </p>
* <p>
* The generation code will also look for a method called {@code combineValueCount}
* which is called once per received block with a count of values. NOTE: We may
* not need this after we convert AVG into a composite operation.
* </p>
* <p>
* The generation code also looks for the optional methods {@code combineIntermediate}
* and {@code evaluateFinal} which are used to combine intermediate states and
* produce the final output. If the first is missing then the generated code will
Expand All @@ -63,4 +58,8 @@
*/
Class<? extends Exception>[] warnExceptions() default {};

/**
* If {@code true} then the @timestamp LongVector will be appended to the input blocks of the aggregation function.
*/
boolean includeTimestamps() default false;
}

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,13 @@ public boolean process(Set<? extends TypeElement> set, RoundEnvironment roundEnv
);
if (aggClass.getAnnotation(Aggregator.class) != null) {
IntermediateState[] intermediateState = aggClass.getAnnotation(Aggregator.class).value();
implementer = new AggregatorImplementer(env.getElementUtils(), aggClass, intermediateState, warnExceptionsTypes);
implementer = new AggregatorImplementer(
env.getElementUtils(),
aggClass,
intermediateState,
warnExceptionsTypes,
aggClass.getAnnotation(Aggregator.class).includeTimestamps()
);
write(aggClass, "aggregator", implementer.sourceFile(), env);
}
GroupingAggregatorImplementer groupingAggregatorImplementer = null;
Expand All @@ -96,13 +102,12 @@ public boolean process(Set<? extends TypeElement> set, RoundEnvironment roundEnv
if (intermediateState.length == 0 && aggClass.getAnnotation(Aggregator.class) != null) {
intermediateState = aggClass.getAnnotation(Aggregator.class).value();
}
boolean includeTimestamps = aggClass.getAnnotation(GroupingAggregator.class).includeTimestamps();
groupingAggregatorImplementer = new GroupingAggregatorImplementer(
env.getElementUtils(),
aggClass,
intermediateState,
warnExceptionsTypes,
includeTimestamps
aggClass.getAnnotation(GroupingAggregator.class).includeTimestamps()
);
write(aggClass, "grouping aggregator", groupingAggregatorImplementer.sourceFile(), env);
}
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,22 @@

import com.squareup.javapoet.TypeName;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import javax.lang.model.element.Element;
import javax.lang.model.element.ExecutableElement;
import javax.lang.model.element.Modifier;
import javax.lang.model.element.TypeElement;
import javax.lang.model.element.VariableElement;
import javax.lang.model.type.DeclaredType;
import javax.lang.model.type.TypeMirror;
import javax.lang.model.type.TypeKind;
import javax.lang.model.util.ElementFilter;
import javax.lang.model.util.Elements;

import static java.util.stream.Collectors.joining;
import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK;
import static org.elasticsearch.compute.gen.Types.BOOLEAN_BLOCK_BUILDER;
import static org.elasticsearch.compute.gen.Types.BOOLEAN_VECTOR;
Expand Down Expand Up @@ -49,30 +53,116 @@
* Finds declared methods for the code generator.
*/
public class Methods {
static ExecutableElement findRequiredMethod(TypeElement declarationType, String[] names, Predicate<ExecutableElement> filter) {
ExecutableElement result = findMethod(names, filter, declarationType, superClassOf(declarationType));
if (result == null) {
if (names.length == 1) {
throw new IllegalArgumentException(declarationType + "#" + names[0] + " is required");
}
throw new IllegalArgumentException("one of " + declarationType + "#" + Arrays.toString(names) + " is required");

static ExecutableElement requireStaticMethod(
TypeElement declarationType,
TypeMatcher returnTypeMatcher,
NameMatcher nameMatcher,
ArgumentMatcher argumentMatcher
) {
return typeAndSuperType(declarationType).flatMap(type -> ElementFilter.methodsIn(type.getEnclosedElements()).stream())
.filter(method -> method.getModifiers().contains(Modifier.STATIC))
.filter(method -> nameMatcher.test(method.getSimpleName().toString()))
.filter(method -> returnTypeMatcher.test(TypeName.get(method.getReturnType())))
.filter(method -> argumentMatcher.test(method.getParameters().stream().map(it -> TypeName.get(it.asType())).toList()))
.findFirst()
.orElseThrow(() -> {
var message = nameMatcher.names.size() == 1 ? "Requires method: " : "Requires one of methods: ";
var signatures = nameMatcher.names.stream()
.map(name -> "public static " + returnTypeMatcher + " " + declarationType + "#" + name + "(" + argumentMatcher + ")")
.collect(joining(" or "));
return new IllegalArgumentException(message + signatures);
});
}

static NameMatcher requireName(String... names) {
return new NameMatcher(Set.of(names));
}

static TypeMatcher requireVoidType() {
return new TypeMatcher(type -> Objects.equals(TypeName.VOID, type), "void");
}

static TypeMatcher requireAnyType(String description) {
return new TypeMatcher(type -> true, description);
}

static TypeMatcher requirePrimitiveOrImplements(Elements elements, TypeName requiredInterface) {
return new TypeMatcher(
type -> type.isPrimitive() || isImplementing(elements, type, requiredInterface),
"[boolean|int|long|float|double|" + requiredInterface + "]"
);
}

static TypeMatcher requireType(TypeName requiredType) {
return new TypeMatcher(type -> Objects.equals(requiredType, type), requiredType.toString());
}

static ArgumentMatcher requireAnyArgs(String description) {
return new ArgumentMatcher(args -> true, description);
}

static ArgumentMatcher requireArgs(TypeMatcher... argTypes) {
return new ArgumentMatcher(
args -> args.size() == argTypes.length && IntStream.range(0, argTypes.length).allMatch(i -> argTypes[i].test(args.get(i))),
Stream.of(argTypes).map(TypeMatcher::toString).collect(joining(", "))
);
}

record NameMatcher(Set<String> names) implements Predicate<String> {
@Override
public boolean test(String name) {
return names.contains(name);
}
return result;
}

static ExecutableElement findMethod(TypeElement declarationType, String name) {
return findMethod(new String[] { name }, e -> true, declarationType, superClassOf(declarationType));
record TypeMatcher(Predicate<TypeName> matcher, String description) implements Predicate<TypeName> {
@Override
public boolean test(TypeName typeName) {
return matcher.test(typeName);
}

@Override
public String toString() {
return description;
}
}

private static TypeElement superClassOf(TypeElement declarationType) {
TypeMirror superclass = declarationType.getSuperclass();
if (superclass instanceof DeclaredType declaredType) {
Element superclassElement = declaredType.asElement();
if (superclassElement instanceof TypeElement) {
return (TypeElement) superclassElement;
}
record ArgumentMatcher(Predicate<List<TypeName>> matcher, String description) implements Predicate<List<TypeName>> {
@Override
public boolean test(List<TypeName> typeName) {
return matcher.test(typeName);
}

@Override
public String toString() {
return description;
}
}

private static boolean isImplementing(Elements elements, TypeName type, TypeName requiredInterface) {
return allInterfacesOf(elements, type).anyMatch(
anInterface -> Objects.equals(anInterface.toString(), requiredInterface.toString())
);
}

private static Stream<TypeName> allInterfacesOf(Elements elements, TypeName type) {
var typeElement = elements.getTypeElement(type.toString());
var superType = Stream.of(typeElement.getSuperclass()).filter(sType -> sType.getKind() != TypeKind.NONE).map(TypeName::get);
var interfaces = typeElement.getInterfaces().stream().map(TypeName::get);
return Stream.concat(
superType.flatMap(sType -> allInterfacesOf(elements, sType)),
interfaces.flatMap(anInterface -> Stream.concat(Stream.of(anInterface), allInterfacesOf(elements, anInterface)))
);
}

private static Stream<TypeElement> typeAndSuperType(TypeElement declarationType) {
if (declarationType.getSuperclass() instanceof DeclaredType declaredType
&& declaredType.asElement() instanceof TypeElement superType) {
return Stream.of(declarationType, superType);
} else {
return Stream.of(declarationType);
}
return null;
}

static ExecutableElement findMethod(TypeElement declarationType, String[] names, Predicate<ExecutableElement> filter) {
Expand All @@ -95,16 +185,6 @@ static ExecutableElement findMethod(String[] names, Predicate<ExecutableElement>
return null;
}

/**
* Returns the arguments of a method after applying a filter.
*/
static VariableElement[] findMethodArguments(ExecutableElement method, Predicate<VariableElement> filter) {
if (method.getParameters().isEmpty()) {
return new VariableElement[0];
}
return method.getParameters().stream().filter(filter).toArray(VariableElement[]::new);
}

/**
* Returns the name of the method used to add {@code valueType} instances
* to vector or block builders.
Expand Down
Loading

0 comments on commit c2e632a

Please sign in to comment.