Skip to content

Commit

Permalink
Reflection bugfixes for missing registration errors
Browse files Browse the repository at this point in the history
  • Loading branch information
loicottet committed Aug 14, 2023
1 parent 63b91c7 commit 11f7735
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 125 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@
import com.oracle.svm.agent.tracing.core.Tracer;
import com.oracle.svm.configure.trace.AccessAdvisor;
import com.oracle.svm.core.c.function.CEntryPointOptions;
import com.oracle.svm.core.jni.JNIObjectHandles;
import com.oracle.svm.core.jni.headers.JNIEnvironment;
import com.oracle.svm.core.jni.headers.JNIFieldId;
import com.oracle.svm.core.jni.headers.JNIMethodId;
Expand Down Expand Up @@ -481,8 +480,7 @@ private static boolean handleInvokeMethod(JNIEnvironment jni, JNIObjectHandle th
*/
if (isInvoke && isClassNewInstance(jni, declaring, name)) {
JNIObjectHandle clazz = getObjectArgument(thread, 1);
JNIMethodId result = newInstanceMethodID(jni, clazz);
traceReflectBreakpoint(jni, clazz, nullHandle(), callerClass, "newInstance", result.notEqual(nullHandle()), state.getFullStackTraceOrNull());
traceReflectBreakpoint(jni, clazz, nullHandle(), callerClass, "newInstance", clazz.notEqual(nullHandle()), state.getFullStackTraceOrNull());
}
return true;
}
Expand Down Expand Up @@ -529,26 +527,10 @@ private static boolean handleInvokeConstructor(JNIEnvironment jni, @SuppressWarn
private static boolean newInstance(JNIEnvironment jni, JNIObjectHandle thread, Breakpoint bp, InterceptedState state) {
JNIObjectHandle callerClass = state.getDirectCallerClass();
JNIObjectHandle self = getReceiver(thread);
JNIMethodId result = newInstanceMethodID(jni, self);
traceReflectBreakpoint(jni, self, nullHandle(), callerClass, bp.specification.methodName, result.notEqual(nullHandle()), state.getFullStackTraceOrNull());
traceReflectBreakpoint(jni, self, nullHandle(), callerClass, bp.specification.methodName, self.notEqual(nullHandle()), state.getFullStackTraceOrNull());
return true;
}

private static JNIMethodId newInstanceMethodID(JNIEnvironment jni, JNIObjectHandle clazz) {
JNIMethodId result = nullPointer();
String name = "<init>";
String signature = "()V";
if (clazz.notEqual(nullHandle())) {
try (CCharPointerHolder ctorName = toCString(name); CCharPointerHolder ctorSignature = toCString(signature)) {
result = jniFunctions().getGetMethodID().invoke(jni, clazz, ctorName.get(), ctorSignature.get());
}
if (clearException(jni)) {
result = nullPointer();
}
}
return result;
}

private static boolean newArrayInstance(JNIEnvironment jni, JNIObjectHandle thread, Breakpoint bp, InterceptedState state) {
JNIValue args = StackValue.get(2, JNIValue.class);
args.addressOf(0).setObject(getObjectArgument(thread, 0));
Expand Down Expand Up @@ -799,14 +781,9 @@ private static boolean loadClass(JNIEnvironment jni, JNIObjectHandle thread, Bre
observedExplicitLoadClassCallSites.put(location, Boolean.TRUE);
}
}
JNIObjectHandle self = getReceiver(thread);
JNIObjectHandle name = getObjectArgument(thread, 1);
String className = fromJniString(jni, name);
JNIObjectHandle clazz = Support.callObjectMethodL(jni, self, bp.method, name);
if (clearException(jni)) {
clazz = nullHandle();
}
traceReflectBreakpoint(jni, bp.clazz, nullHandle(), callerClass, bp.specification.methodName, clazz.notEqual(nullHandle()), state.getFullStackTraceOrNull(), className);
traceReflectBreakpoint(jni, bp.clazz, nullHandle(), callerClass, bp.specification.methodName, className != null, state.getFullStackTraceOrNull(), className);
return true;
}

Expand Down Expand Up @@ -858,70 +835,53 @@ private static boolean isLoadClassInvocation(JNIObjectHandle clazz, JNIMethodId
}
}

private static boolean findMethodHandle(JNIEnvironment jni, JNIObjectHandle thread, Breakpoint bp, InterceptedState state) {
private static boolean findMethodHandle(JNIEnvironment jni, JNIObjectHandle thread, @SuppressWarnings("unused") Breakpoint bp, InterceptedState state) {
JNIObjectHandle callerClass = state.getDirectCallerClass();
JNIObjectHandle lookup = getReceiver(thread);
JNIObjectHandle declaringClass = getObjectArgument(thread, 1);
JNIObjectHandle methodName = getObjectArgument(thread, 2);
JNIObjectHandle methodType = getObjectArgument(thread, 3);

JNIObjectHandle result = Support.callObjectMethodLLL(jni, lookup, bp.method, declaringClass, methodName, methodType);
result = shouldIncludeMethod(jni, result, agent.handles().javaLangIllegalAccessException);

return methodMethodHandle(jni, declaringClass, callerClass, methodName, getParamTypes(jni, methodType), result, state.getFullStackTraceOrNull());
return methodMethodHandle(jni, declaringClass, callerClass, methodName, getParamTypes(jni, methodType), state.getFullStackTraceOrNull());
}

private static boolean findSpecialHandle(JNIEnvironment jni, JNIObjectHandle thread, Breakpoint bp, InterceptedState state) {
private static boolean findSpecialHandle(JNIEnvironment jni, JNIObjectHandle thread, @SuppressWarnings("unused") Breakpoint bp, InterceptedState state) {
JNIObjectHandle callerClass = state.getDirectCallerClass();
JNIObjectHandle lookup = getReceiver(thread);
JNIObjectHandle declaringClass = getObjectArgument(thread, 1);
JNIObjectHandle methodName = getObjectArgument(thread, 2);
JNIObjectHandle methodType = getObjectArgument(thread, 3);
JNIObjectHandle specialCaller = getObjectArgument(thread, 4);

JNIObjectHandle result = Support.callObjectMethodLLLL(jni, lookup, bp.method, declaringClass, methodName, methodType, specialCaller);
result = shouldIncludeMethod(jni, result, agent.handles().javaLangIllegalAccessException);

return methodMethodHandle(jni, declaringClass, callerClass, methodName, getParamTypes(jni, methodType), result, state.getFullStackTraceOrNull());
return methodMethodHandle(jni, declaringClass, callerClass, methodName, getParamTypes(jni, methodType), state.getFullStackTraceOrNull());
}

private static boolean bindHandle(JNIEnvironment jni, JNIObjectHandle thread, Breakpoint bp, InterceptedState state) {
private static boolean bindHandle(JNIEnvironment jni, JNIObjectHandle thread, @SuppressWarnings("unused") Breakpoint bp, InterceptedState state) {
JNIObjectHandle callerClass = state.getDirectCallerClass();
JNIObjectHandle lookup = getReceiver(thread);
JNIObjectHandle receiver = getObjectArgument(thread, 1);
JNIObjectHandle methodName = getObjectArgument(thread, 2);
JNIObjectHandle methodType = getObjectArgument(thread, 3);

JNIObjectHandle result = Support.callObjectMethodLLL(jni, lookup, bp.method, receiver, methodName, methodType);
result = shouldIncludeMethod(jni, result, agent.handles().javaLangIllegalAccessException);

JNIObjectHandle declaringClass = Support.callObjectMethod(jni, receiver, agent.handles().javaLangObjectGetClass);
if (clearException(jni)) {
declaringClass = nullHandle();
}

return methodMethodHandle(jni, declaringClass, callerClass, methodName, getParamTypes(jni, methodType), result, state.getFullStackTraceOrNull());
return methodMethodHandle(jni, declaringClass, callerClass, methodName, getParamTypes(jni, methodType), state.getFullStackTraceOrNull());
}

private static boolean methodMethodHandle(JNIEnvironment jni, JNIObjectHandle declaringClass, JNIObjectHandle callerClass, JNIObjectHandle nameHandle, JNIObjectHandle paramTypesHandle,
JNIObjectHandle result, JNIMethodId[] stackTrace) {
JNIMethodId[] stackTrace) {
String name = fromJniString(jni, nameHandle);
Object paramTypes = getClassArrayNames(jni, paramTypesHandle);
traceReflectBreakpoint(jni, declaringClass, nullHandle(), callerClass, "findMethodHandle", result.notEqual(nullHandle()), stackTrace, name, paramTypes);
traceReflectBreakpoint(jni, declaringClass, nullHandle(), callerClass, "findMethodHandle", declaringClass.notEqual(nullHandle()) && name != null, stackTrace, name, paramTypes);
return true;
}

private static boolean findConstructorHandle(JNIEnvironment jni, JNIObjectHandle thread, Breakpoint bp, InterceptedState state) {
private static boolean findConstructorHandle(JNIEnvironment jni, JNIObjectHandle thread, @SuppressWarnings("unused") Breakpoint bp, InterceptedState state) {
JNIObjectHandle callerClass = state.getDirectCallerClass();
JNIObjectHandle lookup = getReceiver(thread);
JNIObjectHandle declaringClass = getObjectArgument(thread, 1);
JNIObjectHandle methodType = getObjectArgument(thread, 2);

JNIObjectHandle result = Support.callObjectMethodLL(jni, lookup, bp.method, declaringClass, methodType);
result = shouldIncludeMethod(jni, result, agent.handles().javaLangIllegalAccessException);

Object paramTypes = getClassArrayNames(jni, getParamTypes(jni, methodType));
traceReflectBreakpoint(jni, declaringClass, nullHandle(), callerClass, "findConstructorHandle", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), paramTypes);
traceReflectBreakpoint(jni, declaringClass, nullHandle(), callerClass, "findConstructorHandle", declaringClass.notEqual(nullHandle()), state.getFullStackTraceOrNull(), paramTypes);
return true;
}

Expand All @@ -933,42 +893,29 @@ private static JNIObjectHandle getParamTypes(JNIEnvironment jni, JNIObjectHandle
return paramTypesHandle;
}

private static boolean findFieldHandle(JNIEnvironment jni, JNIObjectHandle thread, Breakpoint bp, InterceptedState state) {
private static boolean findFieldHandle(JNIEnvironment jni, JNIObjectHandle thread, @SuppressWarnings("unused") Breakpoint bp, InterceptedState state) {
JNIObjectHandle callerClass = state.getDirectCallerClass();
JNIObjectHandle lookup = getReceiver(thread);
JNIObjectHandle declaringClass = getObjectArgument(thread, 1);
JNIObjectHandle fieldName = getObjectArgument(thread, 2);
JNIObjectHandle fieldType = getObjectArgument(thread, 3);

JNIObjectHandle result = Support.callObjectMethodLLL(jni, lookup, bp.method, declaringClass, fieldName, fieldType);
result = shouldIncludeMethod(jni, result, agent.handles().javaLangIllegalAccessException);

String name = fromJniString(jni, fieldName);
traceReflectBreakpoint(jni, declaringClass, nullHandle(), callerClass, "findFieldHandle", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), name);
traceReflectBreakpoint(jni, declaringClass, nullHandle(), callerClass, "findFieldHandle", declaringClass.notEqual(nullHandle()) && name != null, state.getFullStackTraceOrNull(), name);
return true;
}

private static boolean findClass(JNIEnvironment jni, JNIObjectHandle thread, Breakpoint bp, InterceptedState state) {
JNIObjectHandle callerClass = state.getDirectCallerClass();
JNIObjectHandle lookup = getReceiver(thread);
JNIObjectHandle className = getObjectArgument(thread, 1);

JNIObjectHandle result = Support.callObjectMethodL(jni, lookup, bp.method, className);
result = shouldIncludeMethod(jni, result, agent.handles().javaLangIllegalAccessException);

String name = fromJniString(jni, className);
traceReflectBreakpoint(jni, bp.clazz, nullHandle(), callerClass, "findClass", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), name);
traceReflectBreakpoint(jni, bp.clazz, nullHandle(), callerClass, "findClass", name != null, state.getFullStackTraceOrNull(), name);
return true;
}

private static boolean unreflectField(JNIEnvironment jni, JNIObjectHandle thread, Breakpoint bp, InterceptedState state) {
private static boolean unreflectField(JNIEnvironment jni, JNIObjectHandle thread, @SuppressWarnings("unused") Breakpoint bp, InterceptedState state) {
JNIObjectHandle callerClass = state.getDirectCallerClass();
JNIObjectHandle lookup = getReceiver(thread);
JNIObjectHandle field = getObjectArgument(thread, 1);

JNIObjectHandle result = Support.callObjectMethodL(jni, lookup, bp.method, field);
result = shouldIncludeMethod(jni, result, agent.handles().javaLangIllegalAccessException);

JNIObjectHandle declaringClass = Support.callObjectMethod(jni, field, agent.handles().javaLangReflectMemberGetDeclaringClass);
if (clearException(jni)) {
declaringClass = nullHandle();
Expand All @@ -980,46 +927,34 @@ private static boolean unreflectField(JNIEnvironment jni, JNIObjectHandle thread
}

String fieldName = fromJniString(jni, fieldNameHandle);
traceReflectBreakpoint(jni, declaringClass, nullHandle(), callerClass, "unreflectField", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), fieldName);
traceReflectBreakpoint(jni, declaringClass, nullHandle(), callerClass, "unreflectField", declaringClass.notEqual(nullHandle()) && fieldName != null, state.getFullStackTraceOrNull(),
fieldName);
return true;
}

private static boolean asInterfaceInstance(JNIEnvironment jni, JNIObjectHandle thread, Breakpoint bp, InterceptedState state) {
private static boolean asInterfaceInstance(JNIEnvironment jni, JNIObjectHandle thread, @SuppressWarnings("unused") Breakpoint bp, InterceptedState state) {
JNIObjectHandle callerClass = state.getDirectCallerClass();
JNIObjectHandle intfc = getObjectArgument(thread, 0);
JNIObjectHandle methodHandle = getObjectArgument(thread, 1);

JNIObjectHandle result = Support.callStaticObjectMethodLL(jni, bp.clazz, bp.method, intfc, methodHandle);
result = shouldIncludeMethod(jni, result, agent.handles().javaLangInvokeWrongMethodTypeException, agent.handles().javaLangIllegalArgumentException);

JNIObjectHandle intfcNameHandle = Support.callObjectMethod(jni, intfc, agent.handles().javaLangClassGetName);
if (clearException(jni)) {
intfcNameHandle = nullHandle();
}
String intfcName = fromJniString(jni, intfcNameHandle);
traceReflectBreakpoint(jni, intfc, nullHandle(), callerClass, "asInterfaceInstance", result.notEqual(nullHandle()), state.getFullStackTraceOrNull());
traceReflectBreakpoint(jni, intfc, nullHandle(), callerClass, "asInterfaceInstance", intfcName != null, state.getFullStackTraceOrNull());
String[] intfcNames = new String[]{intfcName};
traceReflectBreakpoint(jni, nullHandle(), nullHandle(), callerClass, "newMethodHandleProxyInstance", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), (Object) intfcNames);
traceReflectBreakpoint(jni, nullHandle(), nullHandle(), callerClass, "newMethodHandleProxyInstance", intfcName != null, state.getFullStackTraceOrNull(), (Object) intfcNames);
return true;
}

private static boolean constantBootstrapGetStaticFinal(JNIEnvironment jni, JNIObjectHandle thread, Breakpoint bp, InterceptedState state, boolean hasDeclaringClass) {
private static boolean constantBootstrapGetStaticFinal(JNIEnvironment jni, JNIObjectHandle thread, @SuppressWarnings("unused") Breakpoint bp, InterceptedState state, boolean hasDeclaringClass) {
JNIObjectHandle callerClass = state.getDirectCallerClass();
JNIObjectHandle lookup = getObjectArgument(thread, 0);
JNIObjectHandle fieldName = getObjectArgument(thread, 1);
JNIObjectHandle type = getObjectArgument(thread, 2);
JNIObjectHandle declaringClass = hasDeclaringClass ? getObjectArgument(thread, 3) : type;

JNIObjectHandle result;
if (hasDeclaringClass) {
result = Support.callStaticObjectMethodLLLL(jni, bp.clazz, bp.method, lookup, fieldName, type, declaringClass);
} else {
result = Support.callStaticObjectMethodLLL(jni, bp.clazz, bp.method, lookup, fieldName, type);
}
result = shouldIncludeMethod(jni, result, agent.handles().javaLangIllegalAccessError);

String name = fromJniString(jni, fieldName);
traceReflectBreakpoint(jni, declaringClass, nullHandle(), callerClass, "findFieldHandle", result.notEqual(nullHandle()), state.getFullStackTraceOrNull(), name);
traceReflectBreakpoint(jni, declaringClass, nullHandle(), callerClass, "findFieldHandle", declaringClass.notEqual(nullHandle()) && name != null, state.getFullStackTraceOrNull(), name);
return true;
}

Expand Down Expand Up @@ -1058,24 +993,6 @@ private static boolean methodTypeFromDescriptor(JNIEnvironment jni, JNIObjectHan
return true;
}

private static JNIObjectHandle shouldIncludeMethod(JNIEnvironment jni, JNIObjectHandle result, JNIObjectHandle... acceptedExceptions) {
JNIObjectHandle exception = handleException(jni, true);
if (exception.notEqual(nullHandle())) {
for (JNIObjectHandle acceptedException : acceptedExceptions) {
if (jniFunctions().getIsInstanceOf().invoke(jni, exception, acceptedException)) {
/*
* We include methods if the lookup returned an IllegalAccessException or a
* WrongMethodTypeException to make sure the right exception is thrown at
* runtime, instead of a NoSuchMethodException.
*/
return JNIObjectHandles.createLocal(Boolean.TRUE);
}
}
return nullHandle();
}
return result;
}

/**
* We have to find a class that captures a lambda function so it can be registered by the agent.
* We have to get a SerializedLambda instance first. After that we get a lambda capturing class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,13 @@ public static void registerClass(Class<?> clazz) {
return; // must be defined at runtime before it can be looked up
}
String name = clazz.getName();
if (!singleton().knownClasses.containsKey(name) || !(singleton().knownClasses.get(name) instanceof Throwable)) {
/*
* If the class has already been seen as throwing an error, we don't overwrite this
* error
*/
VMError.guarantee(!singleton().knownClasses.containsKey(name) || singleton().knownClasses.get(name) == clazz);
singleton().knownClasses.put(name, clazz);
}
Object currentValue = singleton().knownClasses.get(name);
VMError.guarantee(currentValue == null || currentValue == clazz || currentValue instanceof Throwable,
"Invalid Class.forName value for %s: %s", name, currentValue);
/*
* If the class has already been seen as throwing an error, we don't overwrite this error
*/
singleton().knownClasses.putIfAbsent(name, clazz);
}

@Platforms(Platform.HOSTED_ONLY.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ public static void forField(Class<?> declaringClass, String fieldName) {

public static void forMethod(Class<?> declaringClass, String methodName, Class<?>[] paramTypes) {
StringJoiner paramTypeNames = new StringJoiner(", ", "(", ")");
for (Class<?> paramType : paramTypes) {
paramTypeNames.add(paramType.getTypeName());
if (paramTypes != null) {
for (Class<?> paramType : paramTypes) {
paramTypeNames.add(paramType.getTypeName());
}
}
MissingReflectionRegistrationError exception = new MissingReflectionRegistrationError(errorMessage("access method",
declaringClass.getTypeName() + "#" + methodName + paramTypeNames),
Expand Down
Loading

0 comments on commit 11f7735

Please sign in to comment.