From 7754f828ccd173307fcf21f0d728d66f061f2dcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=87=91=E6=88=9F?= Date: Sat, 1 May 2021 16:05:42 +0800 Subject: [PATCH] fix issue of invoking method with single array parameter --- .../testable/core/tool/PrivateAccessor.java | 261 ++++++++++-------- .../core/tool/PrivateAccessorTest.java | 28 ++ 2 files changed, 170 insertions(+), 119 deletions(-) create mode 100644 testable-core/src/test/java/com/alibaba/testable/core/tool/PrivateAccessorTest.java diff --git a/testable-core/src/main/java/com/alibaba/testable/core/tool/PrivateAccessor.java b/testable-core/src/main/java/com/alibaba/testable/core/tool/PrivateAccessor.java index 5feb814..645d7fd 100644 --- a/testable-core/src/main/java/com/alibaba/testable/core/tool/PrivateAccessor.java +++ b/testable-core/src/main/java/com/alibaba/testable/core/tool/PrivateAccessor.java @@ -3,10 +3,7 @@ package com.alibaba.testable.core.tool; import com.alibaba.testable.core.exception.MemberAccessException; import com.alibaba.testable.core.util.TypeUtil; -import java.lang.reflect.Constructor; -import java.lang.reflect.Field; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; +import java.lang.reflect.*; /** * @author flin @@ -19,12 +16,127 @@ public class PrivateAccessor { /** * 读取任意类的私有字段 - * @param ref 目标对象 + * @param ref 目标对象 * @param fieldName 目标字段名 */ public static T get(Object ref, String fieldName) { + return get(ref, ref.getClass(), fieldName); + } + + /** + * 修改任意类的私有字段(或常量字段) + * @param ref 目标对象 + * @param fieldName 目标字段名 + * @param value 目标值 + */ + public static void set(Object ref, String fieldName, T value) { + set(ref, ref.getClass(), fieldName, value); + } + + /** + * 调用任意类的私有方法 + * @param ref 目标对象 + * @param method 目标方法名 + * @param args 方法参数 + */ + public static T invoke(Object ref, String method, Object... args) { + return invoke(ref, ref.getClass(), method, args); + } + + /** + * 读取任意类的静态私有字段 + * @param clazz 目标类型 + * @param fieldName 目标字段名 + */ + public static T getStatic(Class clazz, String fieldName) { + return get(null, clazz, fieldName); + } + + /** + * 修改任意类的静态私有字段(或静态常量字段) + * @param clazz 目标类型 + * @param fieldName 目标字段名 + * @param value 目标值 + */ + public static void setStatic(Class clazz, String fieldName, T value) { + set(null, clazz, fieldName, value); + } + + /** + * 调用任意类的静态私有方法 + * @param clazz 目标类型 + * @param method 目标方法名 + * @param args 方法参数 + */ + public static T invokeStatic(Class clazz, String method, Object... args) { + return invoke(null, clazz, method, args); + } + + /** + * 访问任意类的私有构造方法 + * @param clazz 目标类型 + * @param args 构造方法参数 + */ + public static T construct(Class clazz, Object... args) { try { - Field field = TypeUtil.getFieldByName(ref.getClass(), fieldName); + Constructor constructor = TypeUtil.getConstructorByParameterTypes(clazz, + TypeUtil.getClassesFromObjects(args)); + if (constructor != null) { + constructor.setAccessible(true); + return (T)constructor.newInstance(args); + } + } catch (IllegalAccessException e) { + throw new MemberAccessException("Failed to access private constructor of \"" + + clazz.getSimpleName() + "\"", e); + } catch (InvocationTargetException e) { + if (e.getTargetException() instanceof RuntimeException) { + throw (RuntimeException)e.getTargetException(); + } + throw new MemberAccessException( + "Invoke private constructor of \"" + clazz.getSimpleName() + "\" failed with exception", e); + } catch (InstantiationException e) { + throw new MemberAccessException("Failed to instantiate object of \"" + clazz.getSimpleName() + "\"", e); + } + throw new MemberAccessException("Private constructor of \"" + clazz.getSimpleName() + "\" not exist"); + } + + /** + * 获取所有类型的公共父类 + */ + private static Class getCommonParentClass(Class[] cls) { + if (cls.length < 2 || cls[0] == null || cls[0].isPrimitive()) { + return null; + } + Class commonClass = cls[0]; + for (int i = 1; i < cls.length; i++) { + if (cls[i].isPrimitive()) { + return null; + } else if (cls[i] == null) { + continue; + } + commonClass = getCommonClassOf(commonClass, cls[i]); + } + return commonClass; + } + + /** + * 获取两个类的公共父类 + */ + private static Class getCommonClassOf(Class cls1, Class cls2) { + if (cls1.isAssignableFrom(cls2)) { + return cls1; + } else if (cls2.isAssignableFrom(cls1)) { + return cls2; + } else if (cls1.getSuperclass().equals(Object.class) || cls2.getSuperclass().equals(Object.class)) { + return Object.class; + } else { + return getCommonClassOf(cls1.getSuperclass(), cls2.getSuperclass()); + } + } + + private static T get(Object ref, Class clazz, String fieldName) { + try { + Field field = TypeUtil.getFieldByName(clazz, fieldName); if (field == null) { throw new MemberAccessException("Private field \"" + fieldName + "\" not exist"); } @@ -35,15 +147,9 @@ public class PrivateAccessor { } } - /** - * 修改任意类的私有字段(或常量字段) - * @param ref 目标对象 - * @param fieldName 目标字段名 - * @param value 目标值 - */ - public static void set(Object ref, String fieldName, T value) { + private static void set(Object ref, Class clazz, String fieldName, T value) { try { - Field field = TypeUtil.getFieldByName(ref.getClass(), fieldName); + Field field = TypeUtil.getFieldByName(clazz, fieldName); if (field == null) { throw new MemberAccessException("Private field \"" + fieldName + "\" not exist"); } @@ -54,22 +160,36 @@ public class PrivateAccessor { } } - /** - * 调用任意类的私有方法 - * @param ref 目标对象 - * @param method 目标方法名 - * @param args 方法参数 - */ - public static T invoke(Object ref, String method, Object... args) { + private static T invoke(Object ref, Class clazz, String method, Object... args) { try { Class[] cls = TypeUtil.getClassesFromObjects(args); - Method declaredMethod = TypeUtil.getMethodByNameAndParameterTypes(ref.getClass(), method, cls); + Class commonClass = getCommonParentClass(cls); + Method declaredMethod = TypeUtil.getMethodByNameAndParameterTypes(clazz, method, cls); if (declaredMethod != null) { declaredMethod.setAccessible(true); return (T)declaredMethod.invoke(ref, args); + } else if (commonClass != null) { + Class arrayType = Array.newInstance(commonClass, 0).getClass(); + declaredMethod = TypeUtil.getMethodByNameAndParameterTypes(clazz, method, new Class[] {arrayType}); + if (declaredMethod != null) { + declaredMethod.setAccessible(true); + return (T)declaredMethod.invoke(ref, new Object[] {args}); + } + } + if (ref == null) { + // fit kotlin companion object, will throw 'NoSuchFieldException' otherwise + Field companionClassField = clazz.getDeclaredField(KOTLIN_COMPANION_FIELD); + declaredMethod = TypeUtil.getMethodByNameAndParameterTypes(companionClassField.getType(), method, cls); + Object companionInstance = getStatic(clazz, KOTLIN_COMPANION_FIELD); + if (declaredMethod != null && companionInstance != null) { + declaredMethod.setAccessible(true); + return (T)declaredMethod.invoke(companionInstance, args); + } } } catch (IllegalAccessException e) { throw new MemberAccessException("Failed to access private method \"" + method + "\"", e); + } catch (NoSuchFieldException e) { + throw new MemberAccessException("Private method \"" + method + "\" not exist"); } catch (InvocationTargetException e) { if (e.getTargetException() instanceof RuntimeException) { throw (RuntimeException)e.getTargetException(); @@ -78,101 +198,4 @@ public class PrivateAccessor { } throw new MemberAccessException("Private method \"" + method + "\" not exist"); } - - /** - * 读取任意类的静态私有字段 - * @param clazz 目标类型 - * @param fieldName 目标字段名 - */ - public static T getStatic(Class clazz, String fieldName) { - try { - Field field = TypeUtil.getFieldByName(clazz, fieldName); - if (field == null) { - throw new MemberAccessException("Private static field \"" + fieldName + "\" not exist"); - } - field.setAccessible(true); - return (T)field.get(null); - } catch (IllegalAccessException e) { - throw new MemberAccessException("Failed to access private static field \"" + fieldName + "\"", e); - } - } - - /** - * 修改任意类的静态私有字段(或静态常量字段) - * @param clazz 目标类型 - * @param fieldName 目标字段名 - * @param value 目标值 - */ - public static void setStatic(Class clazz, String fieldName, T value) { - try { - Field field = TypeUtil.getFieldByName(clazz, fieldName); - if (field == null) { - throw new MemberAccessException("Private static field \"" + fieldName + "\" not exist"); - } - field.setAccessible(true); - field.set(null, value); - } catch (IllegalAccessException e) { - throw new MemberAccessException("Failed to access private static field \"" + fieldName + "\"", e); - } - } - - /** - * 调用任意类的静态私有方法 - * @param clazz 目标类型 - * @param method 目标方法名 - * @param args 方法参数 - */ - public static T invokeStatic(Class clazz, String method, Object... args) { - try { - Class[] cls = TypeUtil.getClassesFromObjects(args); - Method declaredMethod = TypeUtil.getMethodByNameAndParameterTypes(clazz, method, cls); - if (declaredMethod != null) { - declaredMethod.setAccessible(true); - return (T)declaredMethod.invoke(null, args); - } - // fit kotlin companion object, will throw 'NoSuchFieldException' otherwise - Field companionClassField = clazz.getDeclaredField(KOTLIN_COMPANION_FIELD); - declaredMethod = TypeUtil.getMethodByNameAndParameterTypes(companionClassField.getType(), method, cls); - Object companionInstance = getStatic(clazz, KOTLIN_COMPANION_FIELD); - if (declaredMethod != null && companionInstance != null) { - declaredMethod.setAccessible(true); - return (T)declaredMethod.invoke(companionInstance, args); - } - } catch (IllegalAccessException e) { - throw new MemberAccessException("Failed to access private static method \"" + method + "\"", e); - } catch (NoSuchFieldException e) { - throw new MemberAccessException("Private static method \"" + method + "\" not exist"); - } catch (InvocationTargetException e) { - if (e.getTargetException() instanceof RuntimeException) { - throw (RuntimeException)e.getTargetException(); - } - throw new MemberAccessException("Invoke private static method \"" + method + "\" failed with exception", e); - } - throw new MemberAccessException("Neither Private static method nor companion method \"" + method + "\" exist"); - } - - /** - * 访问任意类的私有构造方法 - * @param clazz 目标类型 - * @param args 构造方法参数 - */ - public static T construct(Class clazz, Object... args) { - try { - Constructor constructor = TypeUtil.getConstructorByParameterTypes(clazz, TypeUtil.getClassesFromObjects(args)); - if (constructor != null) { - constructor.setAccessible(true); - return (T)constructor.newInstance(args); - } - } catch (IllegalAccessException e) { - throw new MemberAccessException("Failed to access private constructor of \"" + clazz.getSimpleName() + "\"", e); - } catch (InvocationTargetException e) { - if (e.getTargetException() instanceof RuntimeException) { - throw (RuntimeException)e.getTargetException(); - } - throw new MemberAccessException("Invoke private constructor of \"" + clazz.getSimpleName() + "\" failed with exception", e); - } catch (InstantiationException e) { - throw new MemberAccessException("Failed to instantiate object of \"" + clazz.getSimpleName() + "\"", e); - } - throw new MemberAccessException("Private constructor of \"" + clazz.getSimpleName() + "\" not exist"); - } } diff --git a/testable-core/src/test/java/com/alibaba/testable/core/tool/PrivateAccessorTest.java b/testable-core/src/test/java/com/alibaba/testable/core/tool/PrivateAccessorTest.java new file mode 100644 index 0000000..abc300d --- /dev/null +++ b/testable-core/src/test/java/com/alibaba/testable/core/tool/PrivateAccessorTest.java @@ -0,0 +1,28 @@ +package com.alibaba.testable.core.tool; + +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Method; + +import static org.junit.jupiter.api.Assertions.*; + +class PrivateAccessorTest { + + static class A {} + static class AB extends A {} + static class AC extends A {} + static class ABC extends AB {} + static class B {} + + @Test + void should_get_common_type() throws Exception { + Method getCommonClassOf = PrivateAccessor.class.getDeclaredMethod("getCommonClassOf", Class.class, Class.class); + getCommonClassOf.setAccessible(true); + assertEquals(A.class, getCommonClassOf.invoke(null, A.class, A.class)); + assertEquals(A.class, getCommonClassOf.invoke(null, A.class, AB.class)); + assertEquals(A.class, getCommonClassOf.invoke(null, ABC.class, A.class)); + assertEquals(Object.class, getCommonClassOf.invoke(null, B.class, A.class)); + assertEquals(A.class, getCommonClassOf.invoke(null, ABC.class, AC.class)); + } + +}