diff --git a/testable-agent/src/main/java/com/alibaba/testable/agent/handler/MockClassHandler.java b/testable-agent/src/main/java/com/alibaba/testable/agent/handler/MockClassHandler.java index 9f70d25..3dc7db4 100644 --- a/testable-agent/src/main/java/com/alibaba/testable/agent/handler/MockClassHandler.java +++ b/testable-agent/src/main/java/com/alibaba/testable/agent/handler/MockClassHandler.java @@ -14,6 +14,8 @@ import org.objectweb.asm.tree.*; import java.util.List; +import static com.alibaba.testable.agent.constant.ByteCodeConst.TYPE_ARRAY; +import static com.alibaba.testable.agent.constant.ByteCodeConst.TYPE_CLASS; import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName; import static com.alibaba.testable.core.constant.ConstPool.CONSTRUCTOR; @@ -44,10 +46,13 @@ public class MockClassHandler extends BaseClassWithContextHandler { mn.access &= ~ACC_PRIVATE; mn.access &= ~ACC_PROTECTED; mn.access |= ACC_PUBLIC; - // below transform order is important + // firstly, unfold target class from annotation to parameter unfoldTargetClass(mn); + // secondly, add invoke recorder at the beginning of mock method injectInvokeRecorder(mn); + // thirdly, add association checker before invoke recorder injectAssociationChecker(mn); + // finally, handle testable util variables handleTestableUtil(mn); } } @@ -170,7 +175,8 @@ public class MockClassHandler extends BaseClassWithContextHandler { if (VOID_RES.equals(returnType)) { il.add(new InsnNode(POP)); il.add(new InsnNode(RETURN)); - } else if (returnType.startsWith("[") || returnType.startsWith("L")) { + } else if (returnType.startsWith(String.valueOf(TYPE_ARRAY)) || + returnType.startsWith(String.valueOf(TYPE_CLASS))) { il.add(new TypeInsnNode(CHECKCAST, returnType)); il.add(new InsnNode(ARETURN)); } else { @@ -187,13 +193,12 @@ public class MockClassHandler extends BaseClassWithContextHandler { Type className; String methodName = mn.name; for (AnnotationNode an : mn.visibleAnnotations) { - if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc)) { - String name = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_METHOD, - null, String.class); + if (isMockMethodAnnotation(an)) { + String name = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_METHOD, null, String.class); if (name != null) { methodName = name; } - } else if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_CONSTRUCTOR).equals(an.desc)) { + } else if (isMockConstructorAnnotation(an)) { methodName = CONSTRUCTOR; } } @@ -207,8 +212,7 @@ public class MockClassHandler extends BaseClassWithContextHandler { private boolean isGlobalScope(MethodNode mn) { for (AnnotationNode an : mn.visibleAnnotations) { - if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc) || - ClassUtil.toByteCodeClassName(ConstPool.MOCK_CONSTRUCTOR).equals(an.desc)) { + if (isMockMethodAnnotation(an) || isMockConstructorAnnotation(an)) { MockScope scope = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_SCOPE, GlobalConfig.getDefaultMockScope(), MockScope.class); if (scope.equals(MockScope.GLOBAL)) { @@ -224,14 +228,23 @@ public class MockClassHandler extends BaseClassWithContextHandler { return false; } for (AnnotationNode an : mn.visibleAnnotations) { - if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc) || - ClassUtil.toByteCodeClassName(ConstPool.MOCK_CONSTRUCTOR).equals(an.desc)) { + if (isMockMethodAnnotation(an) && AnnotationUtil.isValidMockMethod(mn, an)) { + return true; + } else if (isMockConstructorAnnotation(an)) { return true; } } return false; } + private boolean isMockConstructorAnnotation(AnnotationNode an) { + return ClassUtil.toByteCodeClassName(ConstPool.MOCK_CONSTRUCTOR).equals(an.desc); + } + + private boolean isMockMethodAnnotation(AnnotationNode an) { + return ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc); + } + private void injectInvokeRecorder(MethodNode mn) { InsnList il = new InsnList(); il.add(duplicateParameters(mn)); diff --git a/testable-agent/src/main/java/com/alibaba/testable/agent/transformer/MockClassParser.java b/testable-agent/src/main/java/com/alibaba/testable/agent/transformer/MockClassParser.java index 4464d28..f88b863 100644 --- a/testable-agent/src/main/java/com/alibaba/testable/agent/transformer/MockClassParser.java +++ b/testable-agent/src/main/java/com/alibaba/testable/agent/transformer/MockClassParser.java @@ -16,6 +16,7 @@ import org.objectweb.asm.tree.MethodNode; import java.util.ArrayList; import java.util.List; +import static com.alibaba.testable.agent.constant.ByteCodeConst.TYPE_CLASS; import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName; import static com.alibaba.testable.agent.util.MethodUtil.isStatic; import static com.alibaba.testable.core.constant.ConstPool.CONSTRUCTOR; @@ -89,7 +90,7 @@ public class MockClassParser { LogUtil.verbose(" Mock constructor \"%s\" as \"(%s)V\" for \"%s\"", mn.name, MethodUtil.extractParameters(mn.desc), MethodUtil.getReturnType(mn.desc)); addMockConstructor(methodInfos, cn, mn); - } else if (fullClassName.equals(ConstPool.MOCK_METHOD)) { + } else if (fullClassName.equals(ConstPool.MOCK_METHOD) && AnnotationUtil.isValidMockMethod(mn, an)) { LogUtil.verbose(" Mock method \"%s\" as \"%s\"", mn.name, getTargetMethodDesc(mn, an)); String targetMethod = AnnotationUtil.getAnnotationParameter( an, ConstPool.FIELD_TARGET_METHOD, mn.name, String.class); diff --git a/testable-agent/src/main/java/com/alibaba/testable/agent/util/AnnotationUtil.java b/testable-agent/src/main/java/com/alibaba/testable/agent/util/AnnotationUtil.java index 11594d4..b8e708f 100644 --- a/testable-agent/src/main/java/com/alibaba/testable/agent/util/AnnotationUtil.java +++ b/testable-agent/src/main/java/com/alibaba/testable/agent/util/AnnotationUtil.java @@ -1,6 +1,11 @@ package com.alibaba.testable.agent.util; +import com.alibaba.testable.agent.constant.ConstPool; +import org.objectweb.asm.Type; import org.objectweb.asm.tree.AnnotationNode; +import org.objectweb.asm.tree.MethodNode; + +import static com.alibaba.testable.agent.constant.ByteCodeConst.TYPE_CLASS; /** * @author flin @@ -59,4 +64,16 @@ public class AnnotationUtil { } return false; } + + /** + * Check is MockMethod annotation is used on a valid mock method + * @param mn mock method + * @param an MockMethod annotation + * @return valid or not + */ + public static boolean isValidMockMethod(MethodNode mn, AnnotationNode an) { + Type targetClass = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_CLASS, null, Type.class); + String firstParameter = MethodUtil.getFirstParameter(mn.desc); + return targetClass != null || firstParameter.startsWith(String.valueOf(TYPE_CLASS)); + } } diff --git a/testable-agent/src/main/java/com/alibaba/testable/agent/util/MethodUtil.java b/testable-agent/src/main/java/com/alibaba/testable/agent/util/MethodUtil.java index d470bed..d1013b6 100644 --- a/testable-agent/src/main/java/com/alibaba/testable/agent/util/MethodUtil.java +++ b/testable-agent/src/main/java/com/alibaba/testable/agent/util/MethodUtil.java @@ -73,13 +73,13 @@ public class MethodUtil { } /** - * parse method desc, fetch first parameter type + * parse method desc, fetch first parameter type (assume first parameter is an object type) * @param desc method description * @return types of first parameter */ public static String getFirstParameter(String desc) { int typeEdge = desc.indexOf(CLASS_END); - return desc.substring(1, typeEdge + 1); + return typeEdge > 0 ? desc.substring(1, typeEdge + 1) : ""; } /** diff --git a/testable-agent/src/test/java/com/alibaba/testable/agent/util/MethodUtilTest.java b/testable-agent/src/test/java/com/alibaba/testable/agent/util/MethodUtilTest.java index 53fdd87..0f8dc4f 100644 --- a/testable-agent/src/test/java/com/alibaba/testable/agent/util/MethodUtilTest.java +++ b/testable-agent/src/test/java/com/alibaba/testable/agent/util/MethodUtilTest.java @@ -33,6 +33,8 @@ class MethodUtilTest { @Test void should_able_to_get_first_parameter() { assertEquals("Ljava/lang/String;", MethodUtil.getFirstParameter("(Ljava/lang/String;Ljava/lang/Object;I)V")); + assertEquals("Ljava/lang/String;", MethodUtil.getFirstParameter("(Ljava/lang/String;)V")); + assertEquals("", MethodUtil.getFirstParameter("()V")); } }