Skip to content

Commit

Permalink
#2025 - Fix regression in AOT reflection metadata generation.
Browse files Browse the repository at this point in the history
Fixed the detection of abstract classes in AOT metadata generation so that Jackson mixin types get detected again.
  • Loading branch information
odrotbohm committed Jul 19, 2023
1 parent 3479797 commit 577117d
Show file tree
Hide file tree
Showing 5 changed files with 229 additions and 120 deletions.
65 changes: 57 additions & 8 deletions src/main/java/org/springframework/hateoas/aot/AotUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
package org.springframework.hateoas.aot;

import java.io.IOException;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
Expand All @@ -28,6 +30,7 @@
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.ReflectionHints;
import org.springframework.aot.hint.TypeReference;
import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider;
import org.springframework.core.ResolvableType;
Expand Down Expand Up @@ -93,6 +96,20 @@ public static void registerTypeForReflection(Class<?> type, ReflectionHints refl
SEEN_TYPES.add(type);
}

public static void registerTypesForReflection(String packageName, ReflectionHints reflection, TypeFilter... filters) {

// Register RepresentationModel types for full reflection
var provider = AotUtils.getScanner(packageName, filters);

LOGGER.info("Registering Spring HATEOAS types in {} for reflection.", packageName);

provider.findClasses()
.sorted(Comparator.comparing(TypeReference::getName))
.peek(type -> LOGGER.debug("> {}", type.getName()))
.forEach(reference -> reflection.registerType(reference, //
MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, MemberCategory.INVOKE_DECLARED_METHODS));
}

/**
* Extracts the generics from the given model type if the given {@link ResolvableType} is assignable.
*
Expand Down Expand Up @@ -129,15 +146,28 @@ private static Optional<Class<?>> extractGenerics(Class<?> modelType, Resolvable

public static FullTypeScanner getScanner(String packageName, TypeFilter... includeFilters) {

var provider = new ClassPathScanningCandidateComponentProvider(false);
var provider = new ClassPathScanningCandidateComponentProvider(false) {

@Override
protected boolean isCandidateComponent(AnnotatedBeanDefinition beanDefinition) {
return super.isCandidateComponent(beanDefinition) || beanDefinition.getMetadata().isAbstract();
}
};

var filters = new ArrayList<TypeFilter>();
filters.add(new EnforcedPackageFilter(packageName));
filters.add(new AssignableTypeFilter(Object.class));

if (includeFilters.length == 0) {
provider.addIncludeFilter(new AssignableTypeFilter(Object.class));
} else {
Arrays.stream(includeFilters).forEach(provider::addIncludeFilter);
provider.addIncludeFilter(all(filters));
}

provider.addExcludeFilter(new EnforcedPackageFilter(packageName));
for (TypeFilter filter : includeFilters) {

var includeFilterComponents = new ArrayList<>(filters);
includeFilterComponents.add(filter);
provider.addIncludeFilter(all(includeFilterComponents));
}

return () -> provider.findCandidateComponents(packageName).stream()
.map(BeanDefinition::getBeanClassName)
Expand All @@ -150,7 +180,7 @@ public static FullTypeScanner getScanner(String packageName, TypeFilter... inclu
*
* @author Oliver Drotbohm
*/
private static class EnforcedPackageFilter implements TypeFilter {
static class EnforcedPackageFilter implements TypeFilter {

private final String referencePackage;

Expand All @@ -165,11 +195,30 @@ public EnforcedPackageFilter(String referencePackage) {
@Override
public boolean match(MetadataReader metadataReader, MetadataReaderFactory metadataReaderFactory)
throws IOException {
return !referencePackage
return referencePackage
.equals(ClassUtils.getPackageName(metadataReader.getClassMetadata().getClassName()));
}
}

private static TypeFilter all(Collection<TypeFilter> filters) {

return new TypeFilter() {

@Override
public boolean match(MetadataReader metadataReader, MetadataReaderFactory metadataReaderFactory)
throws IOException {

for (TypeFilter filter : filters) {
if (!filter.match(metadataReader, metadataReaderFactory)) {
return false;
}
}

return true;
}
};
}

static interface FullTypeScanner {

abstract Stream<TypeReference> findClasses();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/
package org.springframework.hateoas.aot;

import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.hateoas.RepresentationModel;
Expand All @@ -34,12 +33,9 @@ class HateoasTypesRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {

var packageName = RepresentationModel.class.getPackageName();
var reflection = hints.reflection();

AotUtils.getScanner(RepresentationModel.class.getPackageName()) //
.findClasses() //
.forEach(it -> reflection.registerType(it, //
MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, //
MemberCategory.INVOKE_DECLARED_METHODS));
AotUtils.registerTypesForReflection(packageName, reflection);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,18 @@
*/
package org.springframework.hateoas.aot;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Stream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.TypeReference;
import org.springframework.beans.factory.aot.BeanRegistrationAotContribution;
import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor;
import org.springframework.beans.factory.aot.BeanRegistrationCode;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.MergedAnnotation;
import org.springframework.core.type.classreading.MetadataReader;
import org.springframework.core.type.classreading.MetadataReaderFactory;
import org.springframework.core.type.filter.TypeFilter;
import org.springframework.hateoas.aot.AotUtils.FullTypeScanner;
import org.springframework.hateoas.config.EnableHypermediaSupport;
import org.springframework.hateoas.config.EnableHypermediaSupport.HypermediaType;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -80,8 +68,6 @@ public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registe

static class MediaTypeReflectionAotContribution implements BeanRegistrationAotContribution {

private static final Logger LOGGER = LoggerFactory.getLogger(MediaTypeReflectionAotContribution.class);

private final List<String> mediaTypePackage;
private final Set<String> packagesSeen;

Expand All @@ -105,8 +91,6 @@ public MediaTypeReflectionAotContribution(List<String> mediaTypePackage) {
@Override
public void applyTo(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode) {

var reflection = generationContext.getRuntimeHints().reflection();

mediaTypePackage.forEach(it -> {

if (packagesSeen.contains(it)) {
Expand All @@ -115,97 +99,9 @@ public void applyTo(GenerationContext generationContext, BeanRegistrationCode be

packagesSeen.add(it);

// Register RepresentationModel types for full reflection
FullTypeScanner provider = AotUtils.getScanner(it, //
new JacksonAnnotationPresentFilter(), //
new JacksonSuperTypeFilter());

LOGGER.info("Registering Spring HATEOAS types in {} for reflection.", it);

provider.findClasses()
.sorted(Comparator.comparing(TypeReference::getName))
.peek(type -> LOGGER.debug("> {}", type.getName()))
.forEach(reference -> reflection.registerType(reference, //
MemberCategory.INVOKE_DECLARED_CONSTRUCTORS, MemberCategory.INVOKE_DECLARED_METHODS));
new HypermediaTypesRuntimeHints(it) //
.registerHints(generationContext.getRuntimeHints(), getClass().getClassLoader());
});
}
}

static abstract class TraversingTypeFilter implements TypeFilter {

/*
* (non-Javadoc)
* @see org.springframework.core.type.filter.TypeFilter#match(org.springframework.core.type.classreading.MetadataReader, org.springframework.core.type.classreading.MetadataReaderFactory)
*/
@Override
public boolean match(MetadataReader metadataReader, MetadataReaderFactory metadataReaderFactory)
throws IOException {

if (doMatch(metadataReader, metadataReaderFactory)) {
return true;
}

var classMetadata = metadataReader.getClassMetadata();

String superClassName = classMetadata.getSuperClassName();

if (superClassName != null && !superClassName.startsWith("java")
&& match(metadataReaderFactory.getMetadataReader(superClassName), metadataReaderFactory)) {
return true;
}

for (String names : classMetadata.getInterfaceNames()) {

MetadataReader reader = metadataReaderFactory.getMetadataReader(names);

if (match(reader, metadataReaderFactory)) {
return true;
}
}

return false;
}

protected abstract boolean doMatch(MetadataReader reader, MetadataReaderFactory factory);
}

static class JacksonAnnotationPresentFilter extends TraversingTypeFilter {

private static final Predicate<String> IS_JACKSON_ANNOTATION = it -> it.startsWith("com.fasterxml.jackson");

/*
* (non-Javadoc)
* @see org.springframework.hateoas.aot.HateoasRuntimeHints.TraversingTypeFilter#doMatch(org.springframework.core.type.classreading.MetadataReader, org.springframework.core.type.classreading.MetadataReaderFactory)
*/
@Override
protected boolean doMatch(MetadataReader reader, MetadataReaderFactory factory) {

var annotationMetadata = reader.getAnnotationMetadata();

// Type annotations
return annotationMetadata
.getAnnotationTypes()
.stream()
.anyMatch(IS_JACKSON_ANNOTATION)

// Method annotations
|| annotationMetadata.getDeclaredMethods().stream()
.flatMap(it -> it.getAnnotations().stream())
.map(MergedAnnotation::getType)
.map(Class::getName)
.anyMatch(IS_JACKSON_ANNOTATION);
}
}

static class JacksonSuperTypeFilter extends TraversingTypeFilter {

/*
* (non-Javadoc)
* @see org.springframework.hateoas.aot.HateoasRuntimeHints.TraversingTypeFilter#doMatch(org.springframework.core.type.classreading.MetadataReader, org.springframework.core.type.classreading.MetadataReaderFactory)
*/
@Override
protected boolean doMatch(MetadataReader reader, MetadataReaderFactory factory) {
return reader.getClassMetadata().getClassName().startsWith("com.fasterxml.jackson");
}
}
}
Loading

0 comments on commit 577117d

Please sign in to comment.