diff --git a/src/com/esotericsoftware/reflectasm/MethodAccess.java b/src/com/esotericsoftware/reflectasm/MethodAccess.java index da4cbd3..2bf8f0e 100644 --- a/src/com/esotericsoftware/reflectasm/MethodAccess.java +++ b/src/com/esotericsoftware/reflectasm/MethodAccess.java @@ -20,9 +20,9 @@ public abstract class MethodAccess { abstract public Object invoke (Object object, int methodIndex, Object... args); - /** Invokes the first method with the specified name. */ + /** Invokes the first method with the specified name and the specified number of arguments. */ public Object invoke (Object object, String methodName, Object... args) { - return invoke(object, getIndex(methodName), args); + return invoke(object, getIndex(methodName, args.length), args); } /** Returns the index of the first method with the specified name. */ @@ -32,10 +32,18 @@ public abstract class MethodAccess { throw new IllegalArgumentException("Unable to find public method: " + methodName); } + /** Returns the index of the first method with the specified name and param types. */ public int getIndex (String methodName, Class... paramTypes) { for (int i = 0, n = methodNames.length; i < n; i++) if (methodNames[i].equals(methodName) && Arrays.equals(paramTypes, parameterTypes[i])) return i; - throw new IllegalArgumentException("Unable to find public method: " + methodName + " " + Arrays.toString(parameterTypes)); + throw new IllegalArgumentException("Unable to find public method: " + methodName + " " + Arrays.toString(paramTypes)); + } + + /** Returns the index of the first method with the specified name and the specified number of arguments. */ + public int getIndex (String methodName, int paramsCount) { + for (int i = 0, n = methodNames.length; i < n; i++) + if (methodNames[i].equals(methodName) && parameterTypes[i].length==paramsCount) return i; + throw new IllegalArgumentException("Unable to find public method: " + methodName + " with " + paramsCount + " params."); } public String[] getMethodNames () { @@ -47,18 +55,17 @@ public abstract class MethodAccess { } static public MethodAccess get (Class type) { - ArrayList methods = new ArrayList(); - Class nextClass = type; - while (nextClass != Object.class) { - Method[] declaredMethods = nextClass.getDeclaredMethods(); - for (int i = 0, n = declaredMethods.length; i < n; i++) { - Method method = declaredMethods[i]; - int modifiers = method.getModifiers(); - if (Modifier.isStatic(modifiers)) continue; - if (Modifier.isPrivate(modifiers)) continue; - methods.add(method); + ArrayList methods = new ArrayList(); + boolean isInterface = type.isInterface(); + if (!isInterface) { + Class nextClass = type; + while (nextClass != Object.class) { + addDeclaredMethodsToList(nextClass, methods); + nextClass = nextClass.getSuperclass(); } - nextClass = nextClass.getSuperclass(); + } + else { + recursiveAddInterfaceMethodsToList(type, methods); } Class[][] parameterTypes = new Class[methods.size()][]; @@ -72,7 +79,7 @@ public abstract class MethodAccess { String className = type.getName(); String accessClassName = className + "MethodAccess"; if (accessClassName.startsWith("java.")) accessClassName = "reflectasm." + accessClassName; - Class accessClass = null; + Class accessClass; AccessClassLoader loader = AccessClassLoader.get(type); synchronized (loader) { @@ -176,7 +183,7 @@ public abstract class MethodAccess { buffer.append(')'); buffer.append(Type.getDescriptor(method.getReturnType())); - mv.visitMethodInsn(INVOKEVIRTUAL, classNameInternal, method.getName(), buffer.toString()); + mv.visitMethodInsn(isInterface ? INVOKEINTERFACE : INVOKEVIRTUAL, classNameInternal, method.getName(), buffer.toString()); switch (Type.getType(method.getReturnType()).getSort()) { case Type.VOID: @@ -242,4 +249,22 @@ public abstract class MethodAccess { throw new RuntimeException("Error constructing method access class: " + accessClassName, ex); } } + + private static void addDeclaredMethodsToList(Class type, ArrayList methods) { + Method[] declaredMethods = type.getDeclaredMethods(); + for (int i = 0, n = declaredMethods.length; i < n; i++) { + Method method = declaredMethods[i]; + int modifiers = method.getModifiers(); + if (Modifier.isStatic(modifiers)) continue; + if (Modifier.isPrivate(modifiers)) continue; + methods.add(method); + } + } + + private static void recursiveAddInterfaceMethodsToList(Class interfaceType, ArrayList methods) { + addDeclaredMethodsToList(interfaceType, methods); + for (Class nextInterface : interfaceType.getInterfaces()) { + recursiveAddInterfaceMethodsToList(nextInterface, methods); + } + } }