validate mock method before use

This commit is contained in:
金戟 2021-02-25 13:48:33 +08:00
parent 9e7ceb2dc1
commit 24c6a9cc5c
5 changed files with 46 additions and 13 deletions

View File

@ -14,6 +14,8 @@ import org.objectweb.asm.tree.*;
import java.util.List;
import static com.alibaba.testable.agent.constant.ByteCodeConst.TYPE_ARRAY;
import static com.alibaba.testable.agent.constant.ByteCodeConst.TYPE_CLASS;
import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName;
import static com.alibaba.testable.core.constant.ConstPool.CONSTRUCTOR;
@ -44,10 +46,13 @@ public class MockClassHandler extends BaseClassWithContextHandler {
mn.access &= ~ACC_PRIVATE;
mn.access &= ~ACC_PROTECTED;
mn.access |= ACC_PUBLIC;
// below transform order is important
// firstly, unfold target class from annotation to parameter
unfoldTargetClass(mn);
// secondly, add invoke recorder at the beginning of mock method
injectInvokeRecorder(mn);
// thirdly, add association checker before invoke recorder
injectAssociationChecker(mn);
// finally, handle testable util variables
handleTestableUtil(mn);
}
}
@ -170,7 +175,8 @@ public class MockClassHandler extends BaseClassWithContextHandler {
if (VOID_RES.equals(returnType)) {
il.add(new InsnNode(POP));
il.add(new InsnNode(RETURN));
} else if (returnType.startsWith("[") || returnType.startsWith("L")) {
} else if (returnType.startsWith(String.valueOf(TYPE_ARRAY)) ||
returnType.startsWith(String.valueOf(TYPE_CLASS))) {
il.add(new TypeInsnNode(CHECKCAST, returnType));
il.add(new InsnNode(ARETURN));
} else {
@ -187,13 +193,12 @@ public class MockClassHandler extends BaseClassWithContextHandler {
Type className;
String methodName = mn.name;
for (AnnotationNode an : mn.visibleAnnotations) {
if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc)) {
String name = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_METHOD,
null, String.class);
if (isMockMethodAnnotation(an)) {
String name = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_METHOD, null, String.class);
if (name != null) {
methodName = name;
}
} else if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_CONSTRUCTOR).equals(an.desc)) {
} else if (isMockConstructorAnnotation(an)) {
methodName = CONSTRUCTOR;
}
}
@ -207,8 +212,7 @@ public class MockClassHandler extends BaseClassWithContextHandler {
private boolean isGlobalScope(MethodNode mn) {
for (AnnotationNode an : mn.visibleAnnotations) {
if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc) ||
ClassUtil.toByteCodeClassName(ConstPool.MOCK_CONSTRUCTOR).equals(an.desc)) {
if (isMockMethodAnnotation(an) || isMockConstructorAnnotation(an)) {
MockScope scope = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_SCOPE,
GlobalConfig.getDefaultMockScope(), MockScope.class);
if (scope.equals(MockScope.GLOBAL)) {
@ -224,14 +228,23 @@ public class MockClassHandler extends BaseClassWithContextHandler {
return false;
}
for (AnnotationNode an : mn.visibleAnnotations) {
if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc) ||
ClassUtil.toByteCodeClassName(ConstPool.MOCK_CONSTRUCTOR).equals(an.desc)) {
if (isMockMethodAnnotation(an) && AnnotationUtil.isValidMockMethod(mn, an)) {
return true;
} else if (isMockConstructorAnnotation(an)) {
return true;
}
}
return false;
}
private boolean isMockConstructorAnnotation(AnnotationNode an) {
return ClassUtil.toByteCodeClassName(ConstPool.MOCK_CONSTRUCTOR).equals(an.desc);
}
private boolean isMockMethodAnnotation(AnnotationNode an) {
return ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc);
}
private void injectInvokeRecorder(MethodNode mn) {
InsnList il = new InsnList();
il.add(duplicateParameters(mn));

View File

@ -16,6 +16,7 @@ import org.objectweb.asm.tree.MethodNode;
import java.util.ArrayList;
import java.util.List;
import static com.alibaba.testable.agent.constant.ByteCodeConst.TYPE_CLASS;
import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName;
import static com.alibaba.testable.agent.util.MethodUtil.isStatic;
import static com.alibaba.testable.core.constant.ConstPool.CONSTRUCTOR;
@ -89,7 +90,7 @@ public class MockClassParser {
LogUtil.verbose(" Mock constructor \"%s\" as \"(%s)V\" for \"%s\"", mn.name,
MethodUtil.extractParameters(mn.desc), MethodUtil.getReturnType(mn.desc));
addMockConstructor(methodInfos, cn, mn);
} else if (fullClassName.equals(ConstPool.MOCK_METHOD)) {
} else if (fullClassName.equals(ConstPool.MOCK_METHOD) && AnnotationUtil.isValidMockMethod(mn, an)) {
LogUtil.verbose(" Mock method \"%s\" as \"%s\"", mn.name, getTargetMethodDesc(mn, an));
String targetMethod = AnnotationUtil.getAnnotationParameter(
an, ConstPool.FIELD_TARGET_METHOD, mn.name, String.class);

View File

@ -1,6 +1,11 @@
package com.alibaba.testable.agent.util;
import com.alibaba.testable.agent.constant.ConstPool;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AnnotationNode;
import org.objectweb.asm.tree.MethodNode;
import static com.alibaba.testable.agent.constant.ByteCodeConst.TYPE_CLASS;
/**
* @author flin
@ -59,4 +64,16 @@ public class AnnotationUtil {
}
return false;
}
/**
* Check is MockMethod annotation is used on a valid mock method
* @param mn mock method
* @param an MockMethod annotation
* @return valid or not
*/
public static boolean isValidMockMethod(MethodNode mn, AnnotationNode an) {
Type targetClass = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_CLASS, null, Type.class);
String firstParameter = MethodUtil.getFirstParameter(mn.desc);
return targetClass != null || firstParameter.startsWith(String.valueOf(TYPE_CLASS));
}
}

View File

@ -73,13 +73,13 @@ public class MethodUtil {
}
/**
* parse method desc, fetch first parameter type
* parse method desc, fetch first parameter type (assume first parameter is an object type)
* @param desc method description
* @return types of first parameter
*/
public static String getFirstParameter(String desc) {
int typeEdge = desc.indexOf(CLASS_END);
return desc.substring(1, typeEdge + 1);
return typeEdge > 0 ? desc.substring(1, typeEdge + 1) : "";
}
/**

View File

@ -33,6 +33,8 @@ class MethodUtilTest {
@Test
void should_able_to_get_first_parameter() {
assertEquals("Ljava/lang/String;", MethodUtil.getFirstParameter("(Ljava/lang/String;Ljava/lang/Object;I)V"));
assertEquals("Ljava/lang/String;", MethodUtil.getFirstParameter("(Ljava/lang/String;)V"));
assertEquals("", MethodUtil.getFirstParameter("()V"));
}
}