invoke origin method if not associated

This commit is contained in:
金戟 2021-02-17 22:59:24 +08:00
parent b4a29e7e8e
commit 03ec857cd0
5 changed files with 161 additions and 40 deletions

View File

@ -11,7 +11,6 @@ import org.objectweb.asm.tree.*;
import java.util.List; import java.util.List;
import static com.alibaba.testable.agent.constant.ConstPool.CONSTRUCTOR;
import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName; import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName;
/** /**
@ -23,7 +22,7 @@ public class MockClassHandler extends BaseClassWithContextHandler {
private static final String CLASS_MOCK_ASSOCIATION_UTIL = "com/alibaba/testable/core/util/MockAssociationUtil"; private static final String CLASS_MOCK_ASSOCIATION_UTIL = "com/alibaba/testable/core/util/MockAssociationUtil";
private static final String METHOD_INVOKE_ORIGIN = "invokeOrigin"; private static final String METHOD_INVOKE_ORIGIN = "invokeOrigin";
private static final String SIGNATURE_INVOKE_ORIGIN = private static final String SIGNATURE_INVOKE_ORIGIN =
"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/Object;[Ljava/lang/Object;)Ljava/lang/Object;"; "(Ljava/lang/Class;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/Object;";
private static final String METHOD_RECORD_MOCK_INVOKE = "recordMockInvoke"; private static final String METHOD_RECORD_MOCK_INVOKE = "recordMockInvoke";
private static final String SIGNATURE_RECORDER_METHOD_INVOKE = "([Ljava/lang/Object;Z)V"; private static final String SIGNATURE_RECORDER_METHOD_INVOKE = "([Ljava/lang/Object;Z)V";
private static final String METHOD_IS_ASSOCIATED = "isAssociated"; private static final String METHOD_IS_ASSOCIATED = "isAssociated";
@ -61,7 +60,7 @@ public class MockClassHandler extends BaseClassWithContextHandler {
il.add(new JumpInsnNode(IFNONNULL, label)); il.add(new JumpInsnNode(IFNONNULL, label));
il.add(new TypeInsnNode(NEW, mockClassName)); il.add(new TypeInsnNode(NEW, mockClassName));
il.add(new InsnNode(DUP)); il.add(new InsnNode(DUP));
il.add(new MethodInsnNode(INVOKESPECIAL, mockClassName, CONSTRUCTOR, VOID_ARGS + VOID_RES, false)); il.add(new MethodInsnNode(INVOKESPECIAL, mockClassName, ConstPool.CONSTRUCTOR, VOID_ARGS + VOID_RES, false));
il.add(new FieldInsnNode(PUTSTATIC, mockClassName, TESTABLE_REF, ClassUtil.toByteCodeClassName(mockClassName))); il.add(new FieldInsnNode(PUTSTATIC, mockClassName, TESTABLE_REF, ClassUtil.toByteCodeClassName(mockClassName)));
il.add(label); il.add(label);
il.add(new FrameNode(F_SAME, 0, null, 0, null)); il.add(new FrameNode(F_SAME, 0, null, 0, null));
@ -123,9 +122,52 @@ public class MockClassHandler extends BaseClassWithContextHandler {
private InsnList invokeOriginalMethod(MethodNode mn) { private InsnList invokeOriginalMethod(MethodNode mn) {
InsnList il = new InsnList(); InsnList il = new InsnList();
mn.maxStack += 3;
ImmutablePair<Type, String> target = getTargetClassAndMethodName(mn);
il.add(new LdcInsnNode(target.left));
il.add(new LdcInsnNode(target.right));
il.add(duplicateParameters(mn));
il.add(new MethodInsnNode(INVOKESTATIC, CLASS_MOCK_ASSOCIATION_UTIL, METHOD_INVOKE_ORIGIN,
SIGNATURE_INVOKE_ORIGIN, false));
String returnType = ClassUtil.getReturnType(mn.desc);
if (VOID_RES.equals(returnType)) {
il.add(new InsnNode(POP));
il.add(new InsnNode(RETURN));
} else if (returnType.startsWith("[") || returnType.startsWith("L")) {
il.add(new TypeInsnNode(CHECKCAST, returnType));
il.add(new InsnNode(ARETURN));
} else {
String wrapperClass = ClassUtil.toWrapperClass(returnType.getBytes()[0]);
il.add(new TypeInsnNode(CHECKCAST, wrapperClass));
ImmutablePair<String, String> convertMethod = ClassUtil.getWrapperTypeConvertMethod(returnType.getBytes()[0]);
il.add(new MethodInsnNode(INVOKEVIRTUAL, wrapperClass, convertMethod.left, convertMethod.right, false));
il.add(new InsnNode(ClassUtil.getReturnOpsCode(returnType)));
}
return il; return il;
} }
private ImmutablePair<Type, String> getTargetClassAndMethodName(MethodNode mn) {
Type className;
String methodName = mn.name;
for (AnnotationNode an : mn.visibleAnnotations) {
if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc)) {
String name = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_METHOD,
null, String.class);
if (name != null) {
methodName = name;
}
} else if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_CONSTRUCTOR).equals(an.desc)) {
methodName = ConstPool.CONSTRUCTOR;
}
}
if (methodName.equals(ConstPool.CONSTRUCTOR)) {
className = Type.getType(ClassUtil.getReturnType(mn.desc));
} else {
className = Type.getType(ClassUtil.getFirstParameter(mn.desc));
}
return ImmutablePair.of(className, methodName);
}
private boolean isGlobalScope(MethodNode mn) { private boolean isGlobalScope(MethodNode mn) {
for (AnnotationNode an : mn.visibleAnnotations) { for (AnnotationNode an : mn.visibleAnnotations) {
if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc) || if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc) ||
@ -140,15 +182,6 @@ public class MockClassHandler extends BaseClassWithContextHandler {
return false; return false;
} }
private LabelNode getFirstLabel(MethodNode mn) {
for (AbstractInsnNode n : mn.instructions) {
if (n instanceof LabelNode) {
return (LabelNode)n;
}
}
return null;
}
private boolean isMockMethod(MethodNode mn) { private boolean isMockMethod(MethodNode mn) {
if (mn.visibleAnnotations == null) { if (mn.visibleAnnotations == null) {
return false; return false;
@ -163,13 +196,26 @@ public class MockClassHandler extends BaseClassWithContextHandler {
} }
private void injectInvokeRecorder(MethodNode mn) { private void injectInvokeRecorder(MethodNode mn) {
InsnList il = new InsnList();
mn.maxStack += 2;
il.add(duplicateParameters(mn));
if (isMockForConstructor(mn)) {
il.add(new InsnNode(ICONST_1));
} else {
il.add(new InsnNode(ICONST_0));
}
il.add(new MethodInsnNode(INVOKESTATIC, CLASS_INVOKE_RECORD_UTIL, METHOD_RECORD_MOCK_INVOKE,
SIGNATURE_RECORDER_METHOD_INVOKE, false));
mn.instructions.insertBefore(mn.instructions.getFirst(), il);
}
private InsnList duplicateParameters(MethodNode mn) {
InsnList il = new InsnList(); InsnList il = new InsnList();
List<Byte> types = ClassUtil.getParameterTypes(mn.desc); List<Byte> types = ClassUtil.getParameterTypes(mn.desc);
int size = types.size(); int size = types.size();
int parameterOffset = 1;
mn.maxStack += 2;
il.add(getIntInsn(size)); il.add(getIntInsn(size));
il.add(new TypeInsnNode(ANEWARRAY, ClassUtil.CLASS_OBJECT)); il.add(new TypeInsnNode(ANEWARRAY, ClassUtil.CLASS_OBJECT));
int parameterOffset = 1;
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
mn.maxStack += 3; mn.maxStack += 3;
il.add(new InsnNode(DUP)); il.add(new InsnNode(DUP));
@ -183,14 +229,7 @@ public class MockClassHandler extends BaseClassWithContextHandler {
} }
il.add(new InsnNode(AASTORE)); il.add(new InsnNode(AASTORE));
} }
if (isMockForConstructor(mn)) { return il;
il.add(new InsnNode(ICONST_1));
} else {
il.add(new InsnNode(ICONST_0));
}
il.add(new MethodInsnNode(INVOKESTATIC, CLASS_INVOKE_RECORD_UTIL, METHOD_RECORD_MOCK_INVOKE,
SIGNATURE_RECORDER_METHOD_INVOKE, false));
mn.instructions.insertBefore(mn.instructions.getFirst(), il);
} }
private boolean isMockForConstructor(MethodNode mn) { private boolean isMockForConstructor(MethodNode mn) {

View File

@ -195,7 +195,7 @@ public class SourceClassHandler extends BaseClassHandler {
} }
private int stackEffectOfInvocation(String desc) { private int stackEffectOfInvocation(String desc) {
return ClassUtil.getParameterTypes(desc).size() - (ClassUtil.getReturnType(desc).isEmpty() ? 0 : 1); return ClassUtil.getParameterTypes(desc).size() - (ClassUtil.getReturnType(desc).equals(VOID_RES) ? 0 : 1);
} }
private ModifiedInsnNodes replaceNewOps(ClassNode cn, MethodNode mn, String newOperatorInjectMethodName, private ModifiedInsnNodes replaceNewOps(ClassNode cn, MethodNode mn, String newOperatorInjectMethodName,

View File

@ -1,6 +1,7 @@
package com.alibaba.testable.agent.util; package com.alibaba.testable.agent.util;
import com.alibaba.testable.agent.constant.ConstPool; import com.alibaba.testable.agent.constant.ConstPool;
import com.alibaba.testable.agent.tool.ImmutablePair;
import org.objectweb.asm.ClassReader; import org.objectweb.asm.ClassReader;
import org.objectweb.asm.tree.ClassNode; import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.MethodInsnNode; import org.objectweb.asm.tree.MethodInsnNode;
@ -13,7 +14,7 @@ import java.util.Map;
import static com.alibaba.testable.core.constant.ConstPool.MOCK_POSTFIX; import static com.alibaba.testable.core.constant.ConstPool.MOCK_POSTFIX;
import static com.alibaba.testable.core.constant.ConstPool.TEST_POSTFIX; import static com.alibaba.testable.core.constant.ConstPool.TEST_POSTFIX;
import static org.objectweb.asm.Opcodes.INVOKESTATIC; import static org.objectweb.asm.Opcodes.*;
/** /**
* @author flin * @author flin
@ -45,8 +46,19 @@ public class ClassUtil {
private static final String CLASS_BOOLEAN = "java/lang/Boolean"; private static final String CLASS_BOOLEAN = "java/lang/Boolean";
private static final String EMPTY = ""; private static final String EMPTY = "";
private static final String METHOD_VALUE_OF = "valueOf"; private static final String METHOD_VALUE_OF = "valueOf";
private static final String METHOD_BYTE_VALUE = "byteValue";
private static final String METHOD_CHAR_VALUE = "charValue";
private static final String METHOD_DOUBLE_VALUE = "doubleValue";
private static final String METHOD_FLOAT_VALUE = "floatValue";
private static final String METHOD_INT_VALUE = "intValue";
private static final String METHOD_LONG_VALUE = "longValue";
private static final String METHOD_SHORT_VALUE = "shortValue";
private static final String METHOD_BOOLEAN_VALUE = "booleanValue";
private static final Map<Byte, String> TYPE_MAPPING = new HashMap<Byte, String>(); private static final Map<Byte, String> TYPE_MAPPING = new HashMap<Byte, String>();
private static final Map<Byte, ImmutablePair<String, String>> WRAPPER_METHOD_MAPPING =
new HashMap<Byte, ImmutablePair<String, String>>();
private static final Map<String, Integer> RETURN_OP_CODE_MAPPING = new HashMap<String, Integer>();
static { static {
TYPE_MAPPING.put(TYPE_BYTE, CLASS_BYTE); TYPE_MAPPING.put(TYPE_BYTE, CLASS_BYTE);
@ -60,6 +72,28 @@ public class ClassUtil {
TYPE_MAPPING.put(TYPE_VOID, EMPTY); TYPE_MAPPING.put(TYPE_VOID, EMPTY);
} }
static {
WRAPPER_METHOD_MAPPING.put(TYPE_BYTE, ImmutablePair.of(METHOD_BYTE_VALUE, "()" + (char)TYPE_BYTE));
WRAPPER_METHOD_MAPPING.put(TYPE_CHAR, ImmutablePair.of(METHOD_CHAR_VALUE, "()" + (char)TYPE_CHAR));
WRAPPER_METHOD_MAPPING.put(TYPE_DOUBLE, ImmutablePair.of(METHOD_DOUBLE_VALUE, "()" + (char)TYPE_DOUBLE));
WRAPPER_METHOD_MAPPING.put(TYPE_FLOAT, ImmutablePair.of(METHOD_FLOAT_VALUE, "()" + (char)TYPE_FLOAT));
WRAPPER_METHOD_MAPPING.put(TYPE_INT, ImmutablePair.of(METHOD_INT_VALUE, "()" + (char)TYPE_INT));
WRAPPER_METHOD_MAPPING.put(TYPE_LONG, ImmutablePair.of(METHOD_LONG_VALUE, "()" + (char)TYPE_LONG));
WRAPPER_METHOD_MAPPING.put(TYPE_SHORT, ImmutablePair.of(METHOD_SHORT_VALUE, "()" + (char)TYPE_SHORT));
WRAPPER_METHOD_MAPPING.put(TYPE_BOOL, ImmutablePair.of(METHOD_BOOLEAN_VALUE, "()" + (char)TYPE_BOOL));
}
static {
RETURN_OP_CODE_MAPPING.put(new String(new byte[] {TYPE_BYTE}), IRETURN);
RETURN_OP_CODE_MAPPING.put(new String(new byte[] {TYPE_CHAR}), IRETURN);
RETURN_OP_CODE_MAPPING.put(new String(new byte[] {TYPE_DOUBLE}), DRETURN);
RETURN_OP_CODE_MAPPING.put(new String(new byte[] {TYPE_FLOAT}), FRETURN);
RETURN_OP_CODE_MAPPING.put(new String(new byte[] {TYPE_INT}), IRETURN);
RETURN_OP_CODE_MAPPING.put(new String(new byte[] {TYPE_LONG}), LRETURN);
RETURN_OP_CODE_MAPPING.put(new String(new byte[] {TYPE_SHORT}), IRETURN);
RETURN_OP_CODE_MAPPING.put(new String(new byte[] {TYPE_BOOL}), IRETURN);
}
/** /**
* fit kotlin companion class name to original name * fit kotlin companion class name to original name
* @param name a class name (which could be a companion class) * @param name a class name (which could be a companion class)
@ -166,20 +200,49 @@ public class ClassUtil {
*/ */
public static String getReturnType(String desc) { public static String getReturnType(String desc) {
int returnTypeEdge = desc.lastIndexOf(PARAM_END); int returnTypeEdge = desc.lastIndexOf(PARAM_END);
char typeChar = desc.charAt(returnTypeEdge + 1);
if (typeChar == TYPE_ARRAY) {
return desc.substring(returnTypeEdge + 1); return desc.substring(returnTypeEdge + 1);
} else if (typeChar == TYPE_CLASS) {
return desc.substring(returnTypeEdge + 2, desc.length() - 1);
} else if (TYPE_MAPPING.containsKey((byte)typeChar)) {
return TYPE_MAPPING.get((byte)typeChar);
} else {
return EMPTY;
}
} }
/** /**
* Get method node to convert primary type to object type * parse method desc, fetch first parameter type
* @param desc method description
* @return types of first parameter
*/
public static String getFirstParameter(String desc) {
int typeEdge = desc.indexOf(CLASS_END);
return desc.substring(1, typeEdge + 1);
}
/**
* get wrapper class of specified private type
* @param primaryType byte code of private type
* @return byte code of wrapper class
*/
public static String toWrapperClass(Byte primaryType) {
return TYPE_MAPPING.get(primaryType);
}
/**
* get method name and descriptor to convert wrapper type to primary type
* @param primaryType byte code of private type
* @return pair of <method-name, method-descriptor>
*/
public static ImmutablePair<String, String> getWrapperTypeConvertMethod(byte primaryType) {
return WRAPPER_METHOD_MAPPING.get(primaryType);
}
/**
* get byte code for return specified private type
* @param type class type
* @return byte code of return operation
*/
public static int getReturnOpsCode(String type) {
Integer code = RETURN_OP_CODE_MAPPING.get(type);
return (code == null) ? ARETURN : code;
}
/**
* get method node to convert primary type to wrapper type
* @param type primary type to convert * @param type primary type to convert
* @return converter method node * @return converter method node
*/ */
@ -216,6 +279,15 @@ public class ClassUtil {
return (char)TYPE_CLASS + toSlashSeparatedName(className) + (char)CLASS_END; return (char)TYPE_CLASS + toSlashSeparatedName(className) + (char)CLASS_END;
} }
/**
* convert byte code class name to slash separated human readable name
* @param className original name
* @return converted name
*/
public static String toSlashSeparateFullClassName(String className) {
return className.substring(1, className.length() - 1);
}
/** /**
* convert byte code class name to dot separated human readable name * convert byte code class name to dot separated human readable name
* @param className original name * @param className original name

View File

@ -23,13 +23,18 @@ class ClassUtilTest {
@Test @Test
void should_able_to_get_return_type() { void should_able_to_get_return_type() {
assertEquals("", ClassUtil.getReturnType("(Ljava/lang/String;)V")); assertEquals("V", ClassUtil.getReturnType("(Ljava/lang/String;)V"));
assertEquals("java/lang/Integer", ClassUtil.getReturnType("(Ljava/lang/String;)I")); assertEquals("I", ClassUtil.getReturnType("(Ljava/lang/String;)I"));
assertEquals("[I", ClassUtil.getReturnType("(Ljava/lang/String;)[I")); assertEquals("[I", ClassUtil.getReturnType("(Ljava/lang/String;)[I"));
assertEquals("java/lang/String", ClassUtil.getReturnType("(Ljava/lang/String;)Ljava/lang/String;")); assertEquals("Ljava/lang/String;", ClassUtil.getReturnType("(Ljava/lang/String;)Ljava/lang/String;"));
assertEquals("[Ljava/lang/String;", ClassUtil.getReturnType("(Ljava/lang/String;)[Ljava/lang/String;")); assertEquals("[Ljava/lang/String;", ClassUtil.getReturnType("(Ljava/lang/String;)[Ljava/lang/String;"));
} }
@Test
void should_able_to_get_first_parameter() {
assertEquals("Ljava/lang/String;", ClassUtil.getFirstParameter("(Ljava/lang/String;Ljava/lang/Object;I)V"));
}
@Test @Test
void should_able_to_convert_class_name() { void should_able_to_convert_class_name() {
assertEquals("Ljava/lang/String;", ClassUtil.toByteCodeClassName("java.lang.String")); assertEquals("Ljava/lang/String;", ClassUtil.toByteCodeClassName("java.lang.String"));

View File

@ -28,14 +28,19 @@ public class MockAssociationUtil {
*/ */
public static boolean isAssociated() { public static boolean isAssociated() {
MockContext mockContext = MockContextUtil.context.get(); MockContext mockContext = MockContextUtil.context.get();
String testClassName = (mockContext == null) ? "" : mockContext.testClassName; if (mockContext == null) {
// skip the association check
LogUtil.warn("Mock association check is invoked without test context");
return true;
}
String testClassName = mockContext.testClassName;
String mockClassName = Thread.currentThread().getStackTrace()[INDEX_OF_MOCK_CLASS].getClassName(); String mockClassName = Thread.currentThread().getStackTrace()[INDEX_OF_MOCK_CLASS].getClassName();
return isAssociatedByInnerMockClass(testClassName, mockClassName) || return isAssociatedByInnerMockClass(testClassName, mockClassName) ||
isAssociatedByOuterMockClass(testClassName, mockClassName) || isAssociatedByOuterMockClass(testClassName, mockClassName) ||
isAssociatedByMockWithAnnotation(testClassName, mockClassName); isAssociatedByMockWithAnnotation(testClassName, mockClassName);
} }
public static Object invokeOrigin(String originClass, String originMethod, Object originObj, Object... args) { public static Object invokeOrigin(Class<?> originClass, String originMethod, Object... args) {
return null; return null;
} }