Merge pull request #21 from serverperformance/master

Fixes in AccessClassLoader and MethodAccess
This commit is contained in:
Nathan Sweet 2014-01-26 05:24:44 -08:00
commit d918153ffe
6 changed files with 390 additions and 77 deletions

View File

@ -1,35 +1,69 @@
package com.esotericsoftware.reflectasm;
import java.lang.ref.WeakReference;
import java.lang.reflect.Method;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.WeakHashMap;
class AccessClassLoader extends ClassLoader {
static private final ArrayList<AccessClassLoader> accessClassLoaders = new ArrayList();
// Weak-references to ClassLoaders, to avoid PermGen memory leaks for example
// in AppServers/WebContainters if the reflectasm framework (including this class)
// is loaded outside the deployed applications (WAR/EAR) using ReflectASM/Kryo
// (exts, user classpath, etc).
//
// The key is the parent ClassLoader and the value is the AccessClassLoader
// Both are weak-referenced in the HashTable.
static private final WeakHashMap<ClassLoader, WeakReference<AccessClassLoader>> accessClassLoaders = new WeakHashMap<ClassLoader, WeakReference<AccessClassLoader>>();
// Fast-path for classes loaded in the same ClassLoader than this Class
static private final ClassLoader selfContextParentClassLoader = getParentClassLoader(AccessClassLoader.class);
static private volatile AccessClassLoader selfContextAccessClassLoader = new AccessClassLoader(selfContextParentClassLoader);
static AccessClassLoader get (Class type) {
ClassLoader parent = type.getClassLoader();
synchronized (accessClassLoaders) {
for (int i = 0, n = accessClassLoaders.size(); i < n; i++) {
AccessClassLoader accessClassLoader = accessClassLoaders.get(i);
if (accessClassLoader.getParent() == parent) return accessClassLoader;
ClassLoader parent = getParentClassLoader(type);
// 1. fast-path:
if (selfContextParentClassLoader.equals(parent)) {
if (selfContextAccessClassLoader==null) {
// DCL with volatile semantics
synchronized (accessClassLoaders) {
if (selfContextAccessClassLoader==null)
selfContextAccessClassLoader = new AccessClassLoader(selfContextParentClassLoader);
}
}
return selfContextAccessClassLoader;
}
// 2. normal search:
synchronized (accessClassLoaders) {
WeakReference<AccessClassLoader> ref = accessClassLoaders.get(parent);
if (ref!=null) {
AccessClassLoader accessClassLoader = ref.get();
if (accessClassLoader!=null) return accessClassLoader;
else accessClassLoaders.remove(parent); // the value has been GC-reclaimed, but still not the key (defensive sanity)
}
if(parent == null) parent = ClassLoader.getSystemClassLoader();
AccessClassLoader accessClassLoader = new AccessClassLoader(parent);
accessClassLoaders.add(accessClassLoader);
accessClassLoaders.put(parent, new WeakReference<AccessClassLoader>(accessClassLoader));
return accessClassLoader;
}
}
static void remove (ClassLoader parent) {
synchronized (accessClassLoaders) {
for (int i = accessClassLoaders.size() - 1; i >= 0; i--) {
AccessClassLoader accessClassLoader = accessClassLoaders.get(i);
if (accessClassLoader.getParent() == parent) accessClassLoaders.remove(i);
public static void remove (ClassLoader parent) {
// 1. fast-path:
if (selfContextParentClassLoader.equals(parent)) {
selfContextAccessClassLoader = null;
}
else {
// 2. normal search:
synchronized (accessClassLoaders) {
accessClassLoaders.remove(parent);
}
}
}
public static int activeAccessClassLoaders() {
int sz = accessClassLoaders.size();
if (selfContextAccessClassLoader!=null) sz++;
return sz;
}
private AccessClassLoader (ClassLoader parent) {
super(parent);
@ -49,11 +83,17 @@ class AccessClassLoader extends ClassLoader {
// Attempt to load the access class in the same loader, which makes protected and default access members accessible.
Method method = ClassLoader.class.getDeclaredMethod("defineClass", new Class[] {String.class, byte[].class, int.class,
int.class, ProtectionDomain.class});
method.setAccessible(true);
if (!method.isAccessible()) method.setAccessible(true);
return (Class)method.invoke(getParent(), new Object[] {name, bytes, Integer.valueOf(0), Integer.valueOf(bytes.length),
getClass().getProtectionDomain()});
} catch (Exception ignored) {
}
return defineClass(name, bytes, 0, bytes.length, getClass().getProtectionDomain());
}
private static ClassLoader getParentClassLoader(Class type) {
ClassLoader parent = type.getClassLoader();
if (parent == null) parent = ClassLoader.getSystemClassLoader();
return parent;
}
}

View File

@ -1,6 +1,7 @@
package com.esotericsoftware.reflectasm;
import java.lang.reflect.Constructor;
import java.lang.reflect.Modifier;
import org.objectweb.asm.ClassWriter;
@ -46,20 +47,29 @@ public abstract class ConstructorAccess<T> {
String classNameInternal = className.replace('.', '/');
String enclosingClassNameInternal;
boolean isPrivate = false;
if (!isNonStaticMemberClass) {
enclosingClassNameInternal = null;
try {
type.getDeclaredConstructor((Class[])null);
Constructor<T> constructor = type.getDeclaredConstructor((Class[])null);
isPrivate = Modifier.isPrivate(constructor.getModifiers());
} catch (Exception ex) {
throw new RuntimeException("Class cannot be created (missing no-arg constructor): " + type.getName());
throw new RuntimeException("Class cannot be created (missing no-arg constructor): " + type.getName(), ex);
}
if (isPrivate) {
throw new RuntimeException("Class cannot be created (the no-arg constructor is private): " + type.getName());
}
} else {
enclosingClassNameInternal = enclosingType.getName().replace('.', '/');
try {
type.getDeclaredConstructor(enclosingType); // Inner classes should have this.
Constructor<T> constructor = type.getDeclaredConstructor(enclosingType); // Inner classes should have this.
isPrivate = Modifier.isPrivate(constructor.getModifiers());
} catch (Exception ex) {
throw new RuntimeException("Non-static member class cannot be created (missing enclosing class constructor): "
+ type.getName());
+ type.getName(), ex);
}
if (isPrivate) {
throw new RuntimeException("Non-static member class cannot be created (the enclosing class constructor is private): " + type.getName());
}
}

View File

@ -17,12 +17,18 @@ import static org.objectweb.asm.Opcodes.*;
public abstract class MethodAccess {
private String[] methodNames;
private Class[][] parameterTypes;
private Class[] returnTypes;
abstract public Object invoke (Object object, int methodIndex, Object... args);
/** Invokes the first method with the specified name. */
/** Invokes the method with the specified name and the specified param types. */
public Object invoke (Object object, String methodName, Class[] paramTypes, Object... args) {
return invoke(object, getIndex(methodName, paramTypes), args);
}
/** Invokes the first method with the specified name and the specified number of arguments. */
public Object invoke (Object object, String methodName, Object... args) {
return invoke(object, getIndex(methodName), args);
return invoke(object, getIndex(methodName, args==null ? 0 : args.length), args);
}
/** Returns the index of the first method with the specified name. */
@ -32,10 +38,18 @@ public abstract class MethodAccess {
throw new IllegalArgumentException("Unable to find public method: " + methodName);
}
/** Returns the index of the first method with the specified name and param types. */
public int getIndex (String methodName, Class... paramTypes) {
for (int i = 0, n = methodNames.length; i < n; i++)
if (methodNames[i].equals(methodName) && Arrays.equals(paramTypes, parameterTypes[i])) return i;
throw new IllegalArgumentException("Unable to find public method: " + methodName + " " + Arrays.toString(parameterTypes));
throw new IllegalArgumentException("Unable to find public method: " + methodName + " " + Arrays.toString(paramTypes));
}
/** Returns the index of the first method with the specified name and the specified number of arguments. */
public int getIndex (String methodName, int paramsCount) {
for (int i = 0, n = methodNames.length; i < n; i++)
if (methodNames[i].equals(methodName) && parameterTypes[i].length==paramsCount) return i;
throw new IllegalArgumentException("Unable to find public method: " + methodName + " with " + paramsCount + " params.");
}
public String[] getMethodNames () {
@ -45,34 +59,40 @@ public abstract class MethodAccess {
public Class[][] getParameterTypes () {
return parameterTypes;
}
public Class[] getReturnTypes () {
return returnTypes;
}
static public MethodAccess get (Class type) {
ArrayList<Method> methods = new ArrayList();
Class nextClass = type;
while (nextClass != Object.class) {
Method[] declaredMethods = nextClass.getDeclaredMethods();
for (int i = 0, n = declaredMethods.length; i < n; i++) {
Method method = declaredMethods[i];
int modifiers = method.getModifiers();
if (Modifier.isStatic(modifiers)) continue;
if (Modifier.isPrivate(modifiers)) continue;
methods.add(method);
ArrayList<Method> methods = new ArrayList<Method>();
boolean isInterface = type.isInterface();
if (!isInterface) {
Class nextClass = type;
while (nextClass != Object.class) {
addDeclaredMethodsToList(nextClass, methods);
nextClass = nextClass.getSuperclass();
}
nextClass = nextClass.getSuperclass();
}
else {
recursiveAddInterfaceMethodsToList(type, methods);
}
Class[][] parameterTypes = new Class[methods.size()][];
String[] methodNames = new String[methods.size()];
for (int i = 0, n = methodNames.length; i < n; i++) {
int n = methods.size();
String[] methodNames = new String[n];
Class[][] parameterTypes = new Class[n][];
Class[] returnTypes = new Class[n];
for (int i = 0; i < n; i++) {
Method method = methods.get(i);
methodNames[i] = method.getName();
parameterTypes[i] = method.getParameterTypes();
returnTypes[i] = method.getReturnType();
}
String className = type.getName();
String accessClassName = className + "MethodAccess";
if (accessClassName.startsWith("java.")) accessClassName = "reflectasm." + accessClassName;
Class accessClass = null;
Class accessClass;
AccessClassLoader loader = AccessClassLoader.get(type);
synchronized (loader) {
@ -106,14 +126,14 @@ public abstract class MethodAccess {
mv.visitVarInsn(ASTORE, 4);
mv.visitVarInsn(ILOAD, 2);
Label[] labels = new Label[methods.size()];
for (int i = 0, n = labels.length; i < n; i++)
Label[] labels = new Label[n];
for (int i = 0; i < n; i++)
labels[i] = new Label();
Label defaultLabel = new Label();
mv.visitTableSwitchInsn(0, labels.length - 1, defaultLabel, labels);
StringBuilder buffer = new StringBuilder(128);
for (int i = 0, n = labels.length; i < n; i++) {
for (int i = 0; i < n; i++) {
mv.visitLabel(labels[i]);
if (i == 0)
mv.visitFrame(Opcodes.F_APPEND, 1, new Object[] {classNameInternal}, 0, null);
@ -124,8 +144,9 @@ public abstract class MethodAccess {
buffer.setLength(0);
buffer.append('(');
Method method = methods.get(i);
Class[] paramTypes = method.getParameterTypes();
String methodName = methodNames[i];
Class[] paramTypes = parameterTypes[i];
Class returnType = returnTypes[i];
for (int paramIndex = 0; paramIndex < paramTypes.length; paramIndex++) {
mv.visitVarInsn(ALOAD, 3);
mv.visitIntInsn(BIPUSH, paramIndex);
@ -175,10 +196,10 @@ public abstract class MethodAccess {
}
buffer.append(')');
buffer.append(Type.getDescriptor(method.getReturnType()));
mv.visitMethodInsn(INVOKEVIRTUAL, classNameInternal, method.getName(), buffer.toString());
buffer.append(Type.getDescriptor(returnType));
mv.visitMethodInsn(isInterface ? INVOKEINTERFACE : INVOKEVIRTUAL, classNameInternal, methodName, buffer.toString());
switch (Type.getType(method.getReturnType()).getSort()) {
switch (Type.getType(returnType).getSort()) {
case Type.VOID:
mv.visitInsn(ACONST_NULL);
break;
@ -237,9 +258,28 @@ public abstract class MethodAccess {
MethodAccess access = (MethodAccess)accessClass.newInstance();
access.methodNames = methodNames;
access.parameterTypes = parameterTypes;
access.returnTypes = returnTypes;
return access;
} catch (Exception ex) {
throw new RuntimeException("Error constructing method access class: " + accessClassName, ex);
}
}
private static void addDeclaredMethodsToList(Class type, ArrayList<Method> methods) {
Method[] declaredMethods = type.getDeclaredMethods();
for (int i = 0, n = declaredMethods.length; i < n; i++) {
Method method = declaredMethods[i];
int modifiers = method.getModifiers();
if (Modifier.isStatic(modifiers)) continue;
if (Modifier.isPrivate(modifiers)) continue;
methods.add(method);
}
}
private static void recursiveAddInterfaceMethodsToList(Class interfaceType, ArrayList<Method> methods) {
addDeclaredMethodsToList(interfaceType, methods);
for (Class nextInterface : interfaceType.getInterfaces()) {
recursiveAddInterfaceMethodsToList(nextInterface, methods);
}
}
}

View File

@ -4,42 +4,16 @@ package com.esotericsoftware.reflectasm;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import static junit.framework.Assert.assertEquals;
import static junit.framework.Assert.assertFalse;
import static junit.framework.Assert.assertTrue;
import junit.framework.TestCase;
public class ClassLoaderTest extends TestCase {
public void testDifferentClassloaders () throws Exception {
// This classloader can see only the Test class and core Java classes.
ClassLoader testClassLoader = new ClassLoader() {
protected synchronized Class<?> loadClass (String name, boolean resolve) throws ClassNotFoundException {
Class c = findLoadedClass(name);
if (c != null) return c;
if (name.startsWith("java.")) return super.loadClass(name, resolve);
if (!name.equals("com.esotericsoftware.reflectasm.ClassLoaderTest$Test"))
throw new ClassNotFoundException("Class not found on purpose: " + name);
ByteArrayOutputStream output = new ByteArrayOutputStream(32 * 1024);
InputStream input = ClassLoaderTest.class.getResourceAsStream("/" + name.replace('.', '/') + ".class");
if (input == null) return null;
try {
byte[] buffer = new byte[4096];
int total = 0;
while (true) {
int length = input.read(buffer, 0, buffer.length);
if (length == -1) break;
output.write(buffer, 0, length);
}
} catch (IOException ex) {
throw new ClassNotFoundException("Error reading class file.", ex);
} finally {
try {
input.close();
} catch (IOException ignored) {
}
}
byte[] buffer = output.toByteArray();
return defineClass(name, buffer, 0, buffer.length);
}
};
ClassLoader testClassLoader = new TestClassLoader1();
Class testClass = testClassLoader.loadClass("com.esotericsoftware.reflectasm.ClassLoaderTest$Test");
Object testObject = testClass.newInstance();
@ -49,6 +23,90 @@ public class ClassLoaderTest extends TestCase {
assertEquals("first", testObject.toString());
assertEquals("first", access.get(testObject, "name"));
}
public void testAutoUnloadClassloaders () throws Exception {
int initialCount = AccessClassLoader.activeAccessClassLoaders();
ClassLoader testClassLoader1 = new TestClassLoader1();
Class testClass1 = testClassLoader1.loadClass("com.esotericsoftware.reflectasm.ClassLoaderTest$Test");
Object testObject1 = testClass1.newInstance();
FieldAccess access1 = FieldAccess.get(testObject1.getClass());
access1.set(testObject1, "name", "first");
assertEquals("first", testObject1.toString());
assertEquals("first", access1.get(testObject1, "name"));
ClassLoader testClassLoader2 = new TestClassLoader2();
Class testClass2 = testClassLoader2.loadClass("com.esotericsoftware.reflectasm.ClassLoaderTest$Test");
Object testObject2 = testClass2.newInstance();
FieldAccess access2 = FieldAccess.get(testObject2.getClass());
access2.set(testObject2, "name", "second");
assertEquals("second", testObject2.toString());
assertEquals("second", access2.get(testObject2, "name"));
assertEquals(access1.getClass().toString(), access2.getClass().toString()); // Same class names
assertFalse(access1.getClass().equals(access2.getClass())); // But different classes
assertEquals(initialCount+2, AccessClassLoader.activeAccessClassLoaders());
testClassLoader1 = null;
testClass1 = null;
testObject1 = null;
access1 = null;
testClassLoader2 = null;
testClass2 = null;
testObject2 = null;
access2 = null;
// Force GC to reclaim unreachable (or only weak-reachable) objects
System.gc();
try {
Object[] array = new Object[(int) Runtime.getRuntime().maxMemory()];
System.out.println(array.length);
} catch (Throwable e) {
// Ignore OME
}
System.gc();
int times = 0;
while (AccessClassLoader.activeAccessClassLoaders()>1 && times < 50) { // max 5 seconds, should be instant
Thread.sleep(100); // test again
times++;
}
// Yeah, both reclaimed!
assertEquals(1, AccessClassLoader.activeAccessClassLoaders());
}
public void testRemoveClassloaders () throws Exception {
int initialCount = AccessClassLoader.activeAccessClassLoaders();
ClassLoader testClassLoader1 = new TestClassLoader1();
Class testClass1 = testClassLoader1.loadClass("com.esotericsoftware.reflectasm.ClassLoaderTest$Test");
Object testObject1 = testClass1.newInstance();
FieldAccess access1 = FieldAccess.get(testObject1.getClass());
access1.set(testObject1, "name", "first");
assertEquals("first", testObject1.toString());
assertEquals("first", access1.get(testObject1, "name"));
ClassLoader testClassLoader2 = new TestClassLoader2();
Class testClass2 = testClassLoader2.loadClass("com.esotericsoftware.reflectasm.ClassLoaderTest$Test");
Object testObject2 = testClass2.newInstance();
FieldAccess access2 = FieldAccess.get(testObject2.getClass());
access2.set(testObject2, "name", "second");
assertEquals("second", testObject2.toString());
assertEquals("second", access2.get(testObject2, "name"));
assertEquals(access1.getClass().toString(), access2.getClass().toString()); // Same class names
assertFalse(access1.getClass().equals(access2.getClass())); // But different classes
assertEquals(initialCount+2, AccessClassLoader.activeAccessClassLoaders());
AccessClassLoader.remove(testObject1.getClass().getClassLoader());
assertEquals(initialCount+1, AccessClassLoader.activeAccessClassLoaders());
AccessClassLoader.remove(testObject2.getClass().getClassLoader());
assertEquals(initialCount+0, AccessClassLoader.activeAccessClassLoaders());
AccessClassLoader.remove(this.getClass().getClassLoader());
assertEquals(initialCount-1, AccessClassLoader.activeAccessClassLoaders());
}
static public class Test {
public String name;
@ -57,4 +115,38 @@ public class ClassLoaderTest extends TestCase {
return name;
}
}
static public class TestClassLoader1 extends ClassLoader {
protected synchronized Class<?> loadClass (String name, boolean resolve) throws ClassNotFoundException {
Class c = findLoadedClass(name);
if (c != null) return c;
if (name.startsWith("java.")) return super.loadClass(name, resolve);
if (!name.equals("com.esotericsoftware.reflectasm.ClassLoaderTest$Test"))
throw new ClassNotFoundException("Class not found on purpose: " + name);
ByteArrayOutputStream output = new ByteArrayOutputStream(32 * 1024);
InputStream input = ClassLoaderTest.class.getResourceAsStream("/" + name.replace('.', '/') + ".class");
if (input == null) return null;
try {
byte[] buffer = new byte[4096];
int total = 0;
while (true) {
int length = input.read(buffer, 0, buffer.length);
if (length == -1) break;
output.write(buffer, 0, length);
}
} catch (IOException ex) {
throw new ClassNotFoundException("Error reading class file.", ex);
} finally {
try {
input.close();
} catch (IOException ignored) {
}
}
byte[] buffer = output.toByteArray();
return defineClass(name, buffer, 0, buffer.length);
}
}
static public class TestClassLoader2 extends TestClassLoader1 {
}
}

View File

@ -1,6 +1,8 @@
package com.esotericsoftware.reflectasm;
import static junit.framework.Assert.assertEquals;
import static junit.framework.Assert.assertTrue;
import junit.framework.TestCase;
public class ConstructorAccessTest extends TestCase {
@ -20,6 +22,70 @@ public class ConstructorAccessTest extends TestCase {
assertEquals(someObject, access.newInstance());
}
public void testHasArgumentConstructor () {
try {
ConstructorAccess.get(HasArgumentConstructor.class);
assertTrue(false);
}
catch (RuntimeException re) {
System.out.println("Expected exception happened: " + re);
}
catch (Throwable t) {
System.out.println("Unexpected exception happened: " + t);
assertTrue(false);
}
}
public void testHasPrivateConstructor () {
try {
ConstructorAccess.get(HasPrivateConstructor.class);
assertTrue(false);
}
catch (RuntimeException re) {
System.out.println("Expected exception happened: " + re);
}
catch (Throwable t) {
System.out.println("Unexpected exception happened: " + t);
assertTrue(false);
}
}
public void testHasProtectedConstructor () {
try {
ConstructorAccess<HasProtectedConstructor> access = ConstructorAccess.get(HasProtectedConstructor.class);
HasProtectedConstructor newInstance = access.newInstance();
assertEquals("cow", newInstance.getMoo());
}
catch (Throwable t) {
System.out.println("Unexpected exception happened: " + t);
assertTrue(false);
}
}
public void testHasPackageProtectedConstructor () {
try {
ConstructorAccess<HasPackageProtectedConstructor> access = ConstructorAccess.get(HasPackageProtectedConstructor.class);
HasPackageProtectedConstructor newInstance = access.newInstance();
assertEquals("cow", newInstance.getMoo());
}
catch (Throwable t) {
System.out.println("Unexpected exception happened: " + t);
assertTrue(false);
}
}
public void testHasPublicConstructor () {
try {
ConstructorAccess<HasPublicConstructor> access = ConstructorAccess.get(HasPublicConstructor.class);
HasPublicConstructor newInstance = access.newInstance();
assertEquals("cow", newInstance.getMoo());
}
catch (Throwable t) {
System.out.println("Unexpected exception happened: " + t);
assertTrue(false);
}
}
static class PackagePrivateClass {
public String name;
public int intValue;
@ -73,4 +139,51 @@ public class ConstructorAccessTest extends TestCase {
return true;
}
}
static public class HasArgumentConstructor {
public String moo;
public HasArgumentConstructor (String moo) {
this.moo = moo;
}
public boolean equals (Object obj) {
if (this == obj) return true;
if (obj == null) return false;
if (getClass() != obj.getClass()) return false;
HasArgumentConstructor other = (HasArgumentConstructor)obj;
if (moo == null) {
if (other.moo != null) return false;
} else if (!moo.equals(other.moo)) return false;
return true;
}
public String getMoo() {
return moo;
}
}
static public class HasPrivateConstructor extends HasArgumentConstructor {
private HasPrivateConstructor () {
super("cow");
}
}
static public class HasProtectedConstructor extends HasPrivateConstructor {
protected HasProtectedConstructor () {
super();
}
}
static public class HasPackageProtectedConstructor extends HasProtectedConstructor {
HasPackageProtectedConstructor () {
super();
}
}
static public class HasPublicConstructor extends HasPackageProtectedConstructor {
HasPublicConstructor () {
super();
}
}
}

View File

@ -1,7 +1,9 @@
package com.esotericsoftware.reflectasm;
import com.esotericsoftware.reflectasm.FieldAccessTest.EmptyClass;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import static junit.framework.Assert.assertEquals;
import junit.framework.TestCase;
@ -71,6 +73,21 @@ public class MethodAccessTest extends TestCase {
}
}
public void testInvokeInterface () {
MethodAccess access = MethodAccess.get(ConcurrentMap.class);
ConcurrentHashMap<String, String> someMap = new ConcurrentHashMap<String, String>();
someMap.put("first", "one");
someMap.put("second", "two");
Object value;
// invoke a method declared directly in the ConcurrentMap interface
value = access.invoke(someMap, "replace", "first", "foo");
assertEquals("one", value);
// invoke a method declared in the Map superinterface
value = access.invoke(someMap, "size");
assertEquals(someMap.size(), value);
}
static public class EmptyClass {
}
@ -98,4 +115,5 @@ public class MethodAccessTest extends TestCase {
return "test";
}
}
}