validate mock method before use

This commit is contained in:
金戟 2021-02-25 13:48:33 +08:00
parent 9e7ceb2dc1
commit 24c6a9cc5c
5 changed files with 46 additions and 13 deletions

View File

@ -14,6 +14,8 @@ import org.objectweb.asm.tree.*;
import java.util.List; import java.util.List;
import static com.alibaba.testable.agent.constant.ByteCodeConst.TYPE_ARRAY;
import static com.alibaba.testable.agent.constant.ByteCodeConst.TYPE_CLASS;
import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName; import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName;
import static com.alibaba.testable.core.constant.ConstPool.CONSTRUCTOR; import static com.alibaba.testable.core.constant.ConstPool.CONSTRUCTOR;
@ -44,10 +46,13 @@ public class MockClassHandler extends BaseClassWithContextHandler {
mn.access &= ~ACC_PRIVATE; mn.access &= ~ACC_PRIVATE;
mn.access &= ~ACC_PROTECTED; mn.access &= ~ACC_PROTECTED;
mn.access |= ACC_PUBLIC; mn.access |= ACC_PUBLIC;
// below transform order is important // firstly, unfold target class from annotation to parameter
unfoldTargetClass(mn); unfoldTargetClass(mn);
// secondly, add invoke recorder at the beginning of mock method
injectInvokeRecorder(mn); injectInvokeRecorder(mn);
// thirdly, add association checker before invoke recorder
injectAssociationChecker(mn); injectAssociationChecker(mn);
// finally, handle testable util variables
handleTestableUtil(mn); handleTestableUtil(mn);
} }
} }
@ -170,7 +175,8 @@ public class MockClassHandler extends BaseClassWithContextHandler {
if (VOID_RES.equals(returnType)) { if (VOID_RES.equals(returnType)) {
il.add(new InsnNode(POP)); il.add(new InsnNode(POP));
il.add(new InsnNode(RETURN)); il.add(new InsnNode(RETURN));
} else if (returnType.startsWith("[") || returnType.startsWith("L")) { } else if (returnType.startsWith(String.valueOf(TYPE_ARRAY)) ||
returnType.startsWith(String.valueOf(TYPE_CLASS))) {
il.add(new TypeInsnNode(CHECKCAST, returnType)); il.add(new TypeInsnNode(CHECKCAST, returnType));
il.add(new InsnNode(ARETURN)); il.add(new InsnNode(ARETURN));
} else { } else {
@ -187,13 +193,12 @@ public class MockClassHandler extends BaseClassWithContextHandler {
Type className; Type className;
String methodName = mn.name; String methodName = mn.name;
for (AnnotationNode an : mn.visibleAnnotations) { for (AnnotationNode an : mn.visibleAnnotations) {
if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc)) { if (isMockMethodAnnotation(an)) {
String name = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_METHOD, String name = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_METHOD, null, String.class);
null, String.class);
if (name != null) { if (name != null) {
methodName = name; methodName = name;
} }
} else if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_CONSTRUCTOR).equals(an.desc)) { } else if (isMockConstructorAnnotation(an)) {
methodName = CONSTRUCTOR; methodName = CONSTRUCTOR;
} }
} }
@ -207,8 +212,7 @@ public class MockClassHandler extends BaseClassWithContextHandler {
private boolean isGlobalScope(MethodNode mn) { private boolean isGlobalScope(MethodNode mn) {
for (AnnotationNode an : mn.visibleAnnotations) { for (AnnotationNode an : mn.visibleAnnotations) {
if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc) || if (isMockMethodAnnotation(an) || isMockConstructorAnnotation(an)) {
ClassUtil.toByteCodeClassName(ConstPool.MOCK_CONSTRUCTOR).equals(an.desc)) {
MockScope scope = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_SCOPE, MockScope scope = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_SCOPE,
GlobalConfig.getDefaultMockScope(), MockScope.class); GlobalConfig.getDefaultMockScope(), MockScope.class);
if (scope.equals(MockScope.GLOBAL)) { if (scope.equals(MockScope.GLOBAL)) {
@ -224,14 +228,23 @@ public class MockClassHandler extends BaseClassWithContextHandler {
return false; return false;
} }
for (AnnotationNode an : mn.visibleAnnotations) { for (AnnotationNode an : mn.visibleAnnotations) {
if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc) || if (isMockMethodAnnotation(an) && AnnotationUtil.isValidMockMethod(mn, an)) {
ClassUtil.toByteCodeClassName(ConstPool.MOCK_CONSTRUCTOR).equals(an.desc)) { return true;
} else if (isMockConstructorAnnotation(an)) {
return true; return true;
} }
} }
return false; return false;
} }
private boolean isMockConstructorAnnotation(AnnotationNode an) {
return ClassUtil.toByteCodeClassName(ConstPool.MOCK_CONSTRUCTOR).equals(an.desc);
}
private boolean isMockMethodAnnotation(AnnotationNode an) {
return ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc);
}
private void injectInvokeRecorder(MethodNode mn) { private void injectInvokeRecorder(MethodNode mn) {
InsnList il = new InsnList(); InsnList il = new InsnList();
il.add(duplicateParameters(mn)); il.add(duplicateParameters(mn));

View File

@ -16,6 +16,7 @@ import org.objectweb.asm.tree.MethodNode;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static com.alibaba.testable.agent.constant.ByteCodeConst.TYPE_CLASS;
import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName; import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName;
import static com.alibaba.testable.agent.util.MethodUtil.isStatic; import static com.alibaba.testable.agent.util.MethodUtil.isStatic;
import static com.alibaba.testable.core.constant.ConstPool.CONSTRUCTOR; import static com.alibaba.testable.core.constant.ConstPool.CONSTRUCTOR;
@ -89,7 +90,7 @@ public class MockClassParser {
LogUtil.verbose(" Mock constructor \"%s\" as \"(%s)V\" for \"%s\"", mn.name, LogUtil.verbose(" Mock constructor \"%s\" as \"(%s)V\" for \"%s\"", mn.name,
MethodUtil.extractParameters(mn.desc), MethodUtil.getReturnType(mn.desc)); MethodUtil.extractParameters(mn.desc), MethodUtil.getReturnType(mn.desc));
addMockConstructor(methodInfos, cn, mn); addMockConstructor(methodInfos, cn, mn);
} else if (fullClassName.equals(ConstPool.MOCK_METHOD)) { } else if (fullClassName.equals(ConstPool.MOCK_METHOD) && AnnotationUtil.isValidMockMethod(mn, an)) {
LogUtil.verbose(" Mock method \"%s\" as \"%s\"", mn.name, getTargetMethodDesc(mn, an)); LogUtil.verbose(" Mock method \"%s\" as \"%s\"", mn.name, getTargetMethodDesc(mn, an));
String targetMethod = AnnotationUtil.getAnnotationParameter( String targetMethod = AnnotationUtil.getAnnotationParameter(
an, ConstPool.FIELD_TARGET_METHOD, mn.name, String.class); an, ConstPool.FIELD_TARGET_METHOD, mn.name, String.class);

View File

@ -1,6 +1,11 @@
package com.alibaba.testable.agent.util; package com.alibaba.testable.agent.util;
import com.alibaba.testable.agent.constant.ConstPool;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AnnotationNode; import org.objectweb.asm.tree.AnnotationNode;
import org.objectweb.asm.tree.MethodNode;
import static com.alibaba.testable.agent.constant.ByteCodeConst.TYPE_CLASS;
/** /**
* @author flin * @author flin
@ -59,4 +64,16 @@ public class AnnotationUtil {
} }
return false; return false;
} }
/**
* Check is MockMethod annotation is used on a valid mock method
* @param mn mock method
* @param an MockMethod annotation
* @return valid or not
*/
public static boolean isValidMockMethod(MethodNode mn, AnnotationNode an) {
Type targetClass = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_CLASS, null, Type.class);
String firstParameter = MethodUtil.getFirstParameter(mn.desc);
return targetClass != null || firstParameter.startsWith(String.valueOf(TYPE_CLASS));
}
} }

View File

@ -73,13 +73,13 @@ public class MethodUtil {
} }
/** /**
* parse method desc, fetch first parameter type * parse method desc, fetch first parameter type (assume first parameter is an object type)
* @param desc method description * @param desc method description
* @return types of first parameter * @return types of first parameter
*/ */
public static String getFirstParameter(String desc) { public static String getFirstParameter(String desc) {
int typeEdge = desc.indexOf(CLASS_END); int typeEdge = desc.indexOf(CLASS_END);
return desc.substring(1, typeEdge + 1); return typeEdge > 0 ? desc.substring(1, typeEdge + 1) : "";
} }
/** /**

View File

@ -33,6 +33,8 @@ class MethodUtilTest {
@Test @Test
void should_able_to_get_first_parameter() { void should_able_to_get_first_parameter() {
assertEquals("Ljava/lang/String;", MethodUtil.getFirstParameter("(Ljava/lang/String;Ljava/lang/Object;I)V")); assertEquals("Ljava/lang/String;", MethodUtil.getFirstParameter("(Ljava/lang/String;Ljava/lang/Object;I)V"));
assertEquals("Ljava/lang/String;", MethodUtil.getFirstParameter("(Ljava/lang/String;)V"));
assertEquals("", MethodUtil.getFirstParameter("()V"));
} }
} }