Avoid asking parent classloader for access class that has not yet been generated, clean up.

closes 
closes 
This commit is contained in:
NathanSweet 2018-06-11 17:18:42 +02:00
parent 54f453484e
commit da80699d0d
5 changed files with 132 additions and 123 deletions
src/com/esotericsoftware/reflectasm
test/com/esotericsoftware/reflectasm

View File

@ -17,6 +17,7 @@ package com.esotericsoftware.reflectasm;
import java.lang.ref.WeakReference;
import java.lang.reflect.Method;
import java.security.ProtectionDomain;
import java.util.HashSet;
import java.util.WeakHashMap;
class AccessClassLoader extends ClassLoader {
@ -29,9 +30,92 @@ class AccessClassLoader extends ClassLoader {
// Fast-path for classes loaded in the same ClassLoader as this class.
static private final ClassLoader selfContextParentClassLoader = getParentClassLoader(AccessClassLoader.class);
static private volatile AccessClassLoader selfContextAccessClassLoader = new AccessClassLoader(selfContextParentClassLoader);
static private volatile Method defineClassMethod;
private final HashSet<String> localClassNames = new HashSet();
private AccessClassLoader (ClassLoader parent) {
super(parent);
}
/** Returns null if the access class has not yet been defined. */
Class loadAccessClass (String name) {
// No need to check the parent class loader if the access class hasn't been defined yet.
if (localClassNames.contains(name)) {
try {
return loadClass(name, false);
} catch (ClassNotFoundException ex) {
throw new RuntimeException(ex); // Should not happen, since we know the class has been defined.
}
}
return null;
}
Class defineAccessClass (String name, byte[] bytes) throws ClassFormatError {
localClassNames.add(name);
return defineClass(name, bytes);
}
protected Class<?> loadClass (String name, boolean resolve) throws ClassNotFoundException {
// These classes come from the classloader that loaded AccessClassLoader.
if (name.equals(FieldAccess.class.getName())) return FieldAccess.class;
if (name.equals(MethodAccess.class.getName())) return MethodAccess.class;
if (name.equals(ConstructorAccess.class.getName())) return ConstructorAccess.class;
if (name.equals(PublicConstructorAccess.class.getName())) return PublicConstructorAccess.class;
// All other classes come from the classloader that loaded the type we are accessing.
return super.loadClass(name, resolve);
}
Class<?> defineClass (String name, byte[] bytes) throws ClassFormatError {
try {
// Attempt to load the access class in the same loader, which makes protected and default access members accessible.
return (Class<?>)getDefineClassMethod().invoke(getParent(),
new Object[] {name, bytes, Integer.valueOf(0), Integer.valueOf(bytes.length), getClass().getProtectionDomain()});
} catch (Exception ignored) {
// continue with the definition in the current loader (won't have access to protected and package-protected members)
}
return defineClass(name, bytes, 0, bytes.length, getClass().getProtectionDomain());
}
// As per JLS, section 5.3,
// "The runtime package of a class or interface is determined by the package name and defining class loader of the class or
// interface."
static boolean areInSameRuntimeClassLoader (Class type1, Class type2) {
if (type1.getPackage() != type2.getPackage()) {
return false;
}
ClassLoader loader1 = type1.getClassLoader();
ClassLoader loader2 = type2.getClassLoader();
ClassLoader systemClassLoader = ClassLoader.getSystemClassLoader();
if (loader1 == null) {
return (loader2 == null || loader2 == systemClassLoader);
}
if (loader2 == null) return loader1 == systemClassLoader;
return loader1 == loader2;
}
static private ClassLoader getParentClassLoader (Class type) {
ClassLoader parent = type.getClassLoader();
if (parent == null) parent = ClassLoader.getSystemClassLoader();
return parent;
}
static private Method getDefineClassMethod () throws Exception {
// DCL on volatile
if (defineClassMethod == null) {
synchronized (accessClassLoaders) {
defineClassMethod = ClassLoader.class.getDeclaredMethod("defineClass",
new Class[] {String.class, byte[].class, int.class, int.class, ProtectionDomain.class});
try {
defineClassMethod.setAccessible(true);
} catch (Exception ignored) {
}
}
}
return defineClassMethod;
}
static AccessClassLoader get (Class type) {
ClassLoader parent = getParentClassLoader(type);
// 1. fast-path:
@ -60,7 +144,7 @@ class AccessClassLoader extends ClassLoader {
}
}
public static void remove (ClassLoader parent) {
static public void remove (ClassLoader parent) {
// 1. fast-path:
if (selfContextParentClassLoader.equals(parent)) {
selfContextAccessClassLoader = null;
@ -72,75 +156,9 @@ class AccessClassLoader extends ClassLoader {
}
}
public static int activeAccessClassLoaders () {
static public int activeAccessClassLoaders () {
int sz = accessClassLoaders.size();
if (selfContextAccessClassLoader != null) sz++;
return sz;
}
private AccessClassLoader (ClassLoader parent) {
super(parent);
}
protected java.lang.Class<?> loadClass (String name, boolean resolve) throws ClassNotFoundException {
// These classes come from the classloader that loaded AccessClassLoader.
if (name.equals(FieldAccess.class.getName())) return FieldAccess.class;
if (name.equals(MethodAccess.class.getName())) return MethodAccess.class;
if (name.equals(ConstructorAccess.class.getName())) return ConstructorAccess.class;
if (name.equals(PublicConstructorAccess.class.getName())) return PublicConstructorAccess.class;
// All other classes come from the classloader that loaded the type we are accessing.
return super.loadClass(name, resolve);
}
Class<?> defineClass (String name, byte[] bytes) throws ClassFormatError {
try {
// Attempt to load the access class in the same loader, which makes protected and default access members accessible.
return (Class<?>)getDefineClassMethod().invoke(getParent(), new Object[] {name, bytes, Integer.valueOf(0), Integer.valueOf(bytes.length),
getClass().getProtectionDomain()});
} catch (Exception ignored) {
// continue with the definition in the current loader (won't have access to protected and package-protected members)
}
return defineClass(name, bytes, 0, bytes.length, getClass().getProtectionDomain());
}
// As per JLS, section 5.3,
// "The runtime package of a class or interface is determined by the package name and defining class loader of the class or interface."
static boolean areInSameRuntimeClassLoader(Class type1, Class type2) {
if (type1.getPackage()!=type2.getPackage()) {
return false;
}
ClassLoader loader1 = type1.getClassLoader();
ClassLoader loader2 = type2.getClassLoader();
ClassLoader systemClassLoader = ClassLoader.getSystemClassLoader();
if (loader1==null) {
return (loader2==null || loader2==systemClassLoader);
}
if (loader2==null) {
return loader1==systemClassLoader;
}
return loader1==loader2;
}
private static ClassLoader getParentClassLoader (Class type) {
ClassLoader parent = type.getClassLoader();
if (parent == null) parent = ClassLoader.getSystemClassLoader();
return parent;
}
private static Method getDefineClassMethod() throws Exception {
// DCL on volatile
if (defineClassMethod==null) {
synchronized(accessClassLoaders) {
defineClassMethod = ClassLoader.class.getDeclaredMethod("defineClass", new Class[] {String.class, byte[].class, int.class,
int.class, ProtectionDomain.class});
try {
defineClassMethod.setAccessible(true);
}
catch (Exception ignored) {
}
}
}
return defineClassMethod;
}
}

View File

@ -16,13 +16,13 @@ package com.esotericsoftware.reflectasm;
import static com.esotericsoftware.asm.Opcodes.*;
import java.lang.reflect.Constructor;
import java.lang.reflect.Modifier;
import com.esotericsoftware.asm.ClassWriter;
import com.esotericsoftware.asm.MethodVisitor;
public abstract class ConstructorAccess<T> {
import java.lang.reflect.Constructor;
import java.lang.reflect.Modifier;
abstract public class ConstructorAccess<T> {
boolean isNonStaticMemberClass;
public boolean isNonStaticMemberClass () {
@ -48,16 +48,13 @@ public abstract class ConstructorAccess<T> {
String className = type.getName();
String accessClassName = className + "ConstructorAccess";
if (accessClassName.startsWith("java.")) accessClassName = "reflectasm." + accessClassName;
Class accessClass;
AccessClassLoader loader = AccessClassLoader.get(type);
try {
accessClass = loader.loadClass(accessClassName);
} catch (ClassNotFoundException ignored) {
Class accessClass = loader.loadAccessClass(accessClassName);
if (accessClass == null) {
synchronized (loader) {
try {
accessClass = loader.loadClass(accessClassName);
} catch (ClassNotFoundException ignored2) {
accessClass = loader.loadAccessClass(accessClassName);
if (accessClass == null) {
String accessClassNameInternal = accessClassName.replace('.', '/');
String classNameInternal = className.replace('.', '/');
String enclosingClassNameInternal;
@ -80,17 +77,18 @@ public abstract class ConstructorAccess<T> {
constructor = type.getDeclaredConstructor(enclosingType); // Inner classes should have this.
modifiers = constructor.getModifiers();
} catch (Exception ex) {
throw new RuntimeException("Non-static member class cannot be created (missing enclosing class constructor): "
+ type.getName(), ex);
throw new RuntimeException(
"Non-static member class cannot be created (missing enclosing class constructor): " + type.getName(), ex);
}
if (Modifier.isPrivate(modifiers)) {
throw new RuntimeException(
"Non-static member class cannot be created (the enclosing class constructor is private): " + type.getName());
"Non-static member class cannot be created (the enclosing class constructor is private): "
+ type.getName());
}
}
String superclassNameInternal = Modifier.isPublic(modifiers) ?
"com/esotericsoftware/reflectasm/PublicConstructorAccess" :
"com/esotericsoftware/reflectasm/ConstructorAccess";
String superclassNameInternal = Modifier.isPublic(modifiers)
? "com/esotericsoftware/reflectasm/PublicConstructorAccess"
: "com/esotericsoftware/reflectasm/ConstructorAccess";
ClassWriter cw = new ClassWriter(0);
cw.visit(V1_1, ACC_PUBLIC + ACC_SUPER, accessClassNameInternal, null, superclassNameInternal, null);
@ -100,7 +98,7 @@ public abstract class ConstructorAccess<T> {
insertNewInstanceInner(cw, classNameInternal, enclosingClassNameInternal);
cw.visitEnd();
accessClass = loader.defineClass(accessClassName, cw.toByteArray());
accessClass = loader.defineAccessClass(accessClassName, cw.toByteArray());
}
}
}
@ -110,14 +108,13 @@ public abstract class ConstructorAccess<T> {
} catch (Throwable t) {
throw new RuntimeException("Exception constructing constructor access class: " + accessClassName, t);
}
if (!(access instanceof PublicConstructorAccess) && !AccessClassLoader.areInSameRuntimeClassLoader(type, accessClass)) {
if (!(access instanceof PublicConstructorAccess) && !AccessClassLoader.areInSameRuntimeClassLoader(type, accessClass)) {
// Must test this after the try-catch block, whether the class has been loaded as if has been defined.
// Throw a Runtime exception here instead of an IllegalAccessError when invoking newInstance()
throw new RuntimeException(
(!isNonStaticMemberClass ?
"Class cannot be created (the no-arg constructor is protected or package-protected, and its ConstructorAccess could not be defined in the same class loader): " :
"Non-static member class cannot be created (the enclosing class constructor is protected or package-protected, and its ConstructorAccess could not be defined in the same class loader): ")
+ type.getName());
throw new RuntimeException((!isNonStaticMemberClass
? "Class cannot be created (the no-arg constructor is protected or package-protected, and its ConstructorAccess could not be defined in the same class loader): "
: "Non-static member class cannot be created (the enclosing class constructor is protected or package-protected, and its ConstructorAccess could not be defined in the same class loader): ")
+ type.getName());
}
access.isNonStaticMemberClass = isNonStaticMemberClass;
return access;

View File

@ -137,16 +137,13 @@ public abstract class FieldAccess {
String className = type.getName();
String accessClassName = className + "FieldAccess";
if (accessClassName.startsWith("java.")) accessClassName = "reflectasm." + accessClassName;
Class accessClass = null;
AccessClassLoader loader = AccessClassLoader.get(type);
try {
accessClass = loader.loadClass(accessClassName);
} catch (ClassNotFoundException ignored) {
Class accessClass = loader.loadAccessClass(accessClassName);
if (accessClass == null) {
synchronized (loader) {
try {
accessClass = loader.loadClass(accessClassName);
} catch (ClassNotFoundException ignored2) {
accessClass = loader.loadAccessClass(accessClassName);
if (accessClass == null) {
String accessClassNameInternal = accessClassName.replace('.', '/');
String classNameInternal = className.replace('.', '/');
@ -174,7 +171,7 @@ public abstract class FieldAccess {
insertSetPrimitive(cw, classNameInternal, fields, Type.CHAR_TYPE);
insertGetString(cw, classNameInternal, fields);
cw.visitEnd();
accessClass = loader.defineClass(accessClassName, cw.toByteArray());
accessClass = loader.defineAccessClass(accessClassName, cw.toByteArray());
}
}
}

View File

@ -16,17 +16,17 @@ package com.esotericsoftware.reflectasm;
import static com.esotericsoftware.asm.Opcodes.*;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import com.esotericsoftware.asm.ClassWriter;
import com.esotericsoftware.asm.Label;
import com.esotericsoftware.asm.MethodVisitor;
import com.esotericsoftware.asm.Opcodes;
import com.esotericsoftware.asm.Type;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
public abstract class MethodAccess {
private String[] methodNames;
private Class[][] parameterTypes;
@ -78,22 +78,22 @@ public abstract class MethodAccess {
return returnTypes;
}
/** @param type Must not be the Object class, an interface, a primitive type, or void. */
/** Creates a new MethodAccess for the specified type.
* @param type Must not be the Object class, a primitive type, or void. */
static public MethodAccess get (Class type) {
if (type.getSuperclass() == null)
boolean isInterface = type.isInterface();
if (!isInterface && type.getSuperclass() == null)
throw new IllegalArgumentException("The type must not be the Object class, an interface, a primitive type, or void.");
ArrayList<Method> methods = new ArrayList<Method>();
boolean isInterface = type.isInterface();
if (!isInterface) {
Class nextClass = type;
while (nextClass != Object.class) {
addDeclaredMethodsToList(nextClass, methods);
nextClass = nextClass.getSuperclass();
}
} else {
} else
recursiveAddInterfaceMethodsToList(type, methods);
}
int n = methods.size();
String[] methodNames = new String[n];
@ -109,16 +109,13 @@ public abstract class MethodAccess {
String className = type.getName();
String accessClassName = className + "MethodAccess";
if (accessClassName.startsWith("java.")) accessClassName = "reflectasm." + accessClassName;
Class accessClass;
AccessClassLoader loader = AccessClassLoader.get(type);
try {
accessClass = loader.loadClass(accessClassName);
} catch (ClassNotFoundException ignored) {
Class accessClass = loader.loadAccessClass(accessClassName);
if (accessClass == null) {
synchronized (loader) {
try {
accessClass = loader.loadClass(accessClassName);
} catch (ClassNotFoundException ignored2) {
accessClass = loader.loadAccessClass(accessClassName);
if (accessClass == null) {
String accessClassNameInternal = accessClassName.replace('.', '/');
String classNameInternal = className.replace('.', '/');
@ -277,7 +274,7 @@ public abstract class MethodAccess {
}
cw.visitEnd();
byte[] data = cw.toByteArray();
accessClass = loader.defineClass(accessClassName, data);
accessClass = loader.defineAccessClass(accessClassName, data);
}
}
}
@ -292,7 +289,7 @@ public abstract class MethodAccess {
}
}
private static void addDeclaredMethodsToList (Class type, ArrayList<Method> methods) {
static private void addDeclaredMethodsToList (Class type, ArrayList<Method> methods) {
Method[] declaredMethods = type.getDeclaredMethods();
for (int i = 0, n = declaredMethods.length; i < n; i++) {
Method method = declaredMethods[i];
@ -303,10 +300,9 @@ public abstract class MethodAccess {
}
}
private static void recursiveAddInterfaceMethodsToList (Class interfaceType, ArrayList<Method> methods) {
static private void recursiveAddInterfaceMethodsToList (Class interfaceType, ArrayList<Method> methods) {
addDeclaredMethodsToList(interfaceType, methods);
for (Class nextInterface : interfaceType.getInterfaces()) {
for (Class nextInterface : interfaceType.getInterfaces())
recursiveAddInterfaceMethodsToList(nextInterface, methods);
}
}
}

View File

@ -91,6 +91,7 @@ public class MethodAccessTest extends TestCase {
public void testInvokeInterface () {
MethodAccess access = MethodAccess.get(ConcurrentMap.class);
access = MethodAccess.get(ConcurrentMap.class);
ConcurrentHashMap<String, String> someMap = new ConcurrentHashMap<String, String>();
someMap.put("first", "one");
someMap.put("second", "two");