diff --git a/testable-agent/src/main/java/com/alibaba/testable/agent/handler/TestClassHandler.java b/testable-agent/src/main/java/com/alibaba/testable/agent/handler/TestClassHandler.java index 0dc55a0..a79d2f0 100644 --- a/testable-agent/src/main/java/com/alibaba/testable/agent/handler/TestClassHandler.java +++ b/testable-agent/src/main/java/com/alibaba/testable/agent/handler/TestClassHandler.java @@ -1,6 +1,7 @@ package com.alibaba.testable.agent.handler; import com.alibaba.testable.agent.constant.ConstPool; +import com.alibaba.testable.agent.tool.ImmutablePair; import com.alibaba.testable.agent.util.ClassUtil; import org.objectweb.asm.tree.*; @@ -21,7 +22,7 @@ public class TestClassHandler extends BaseClassHandler { private static final String FIELD_SOURCE_METHOD = "SOURCE_METHOD"; private static final String METHOD_CURRENT_TEST_CASE_NAME = "currentTestCaseName"; private static final String METHOD_CURRENT_SOURCE_METHOD_NAME = "currentSourceMethodName"; - private static final String METHOD_COUNT_MOCK_INVOKE = "countMockInvoke"; + private static final String METHOD_RECORD_MOCK_INVOKE = "recordMockInvoke"; private static final String SIGNATURE_TESTABLE_UTIL_METHOD = "(Ljava/lang/Object;)Ljava/lang/String;"; private static final String SIGNATURE_INVOKE_COUNTER_METHOD = "()V"; private static final Map FIELD_TO_METHOD_MAPPING = new HashMap() {{ @@ -59,7 +60,7 @@ public class TestClassHandler extends BaseClassHandler { mn.access &= ~ACC_PRIVATE; mn.access &= ~ACC_PROTECTED; mn.access |= ACC_PUBLIC; - injectInvokeCounter(mn); + injectInvokeRecorder(mn); } else if (couldBeTestMethod(mn)) { injectTestableRef(cn, mn); } @@ -95,10 +96,66 @@ public class TestClassHandler extends BaseClassHandler { return mn.instructions.toArray(); } - private void injectInvokeCounter(MethodNode mn) { - MethodInsnNode node = new MethodInsnNode(INVOKESTATIC, CLASS_INVOKE_RECORD_UTIL, METHOD_COUNT_MOCK_INVOKE, - SIGNATURE_INVOKE_COUNTER_METHOD, false); - mn.instructions.insertBefore(mn.instructions.get(0), node); + private void injectInvokeRecorder(MethodNode mn) { + InsnList il = new InsnList(); + List types = ClassUtil.getParameterTypes(mn.desc); + int size = mn.parameters.size(); + int parameterOffset = 1; + il.add(getIntInsn(size)); + il.add(new TypeInsnNode(ANEWARRAY, ClassUtil.CLASS_OBJECT)); + for (int i = 0; i < size; i++) { + il.add(new InsnNode(DUP)); + il.add(getIntInsn(i)); + ImmutablePair code = getLoadParameterByteCode(types.get(i)); + il.add(new VarInsnNode(code.left, parameterOffset)); + parameterOffset += code.right; + MethodInsnNode typeConvertMethodNode = ClassUtil.getPrimaryTypeConvertMethod(types.get(i)); + if (typeConvertMethodNode != null) { + il.add(typeConvertMethodNode); + } + il.add(new InsnNode(AASTORE)); + } + il.add(new MethodInsnNode(INVOKESTATIC, CLASS_INVOKE_RECORD_UTIL, METHOD_RECORD_MOCK_INVOKE, + SIGNATURE_INVOKE_COUNTER_METHOD, false)); + mn.instructions.insertBefore(mn.instructions.get(0), il); + } + + private static ImmutablePair getLoadParameterByteCode(Byte type) { + switch (type) { + case ClassUtil.TYPE_BYTE: + case ClassUtil.TYPE_CHAR: + case ClassUtil.TYPE_SHORT: + case ClassUtil.TYPE_INT: + case ClassUtil.TYPE_BOOL: + return ImmutablePair.of(ILOAD, 1); + case ClassUtil.TYPE_DOUBLE: + return ImmutablePair.of(DLOAD, 2); + case ClassUtil.TYPE_FLOAT: + return ImmutablePair.of(FLOAD, 1); + case ClassUtil.TYPE_LONG: + return ImmutablePair.of(LLOAD, 2); + default: + return ImmutablePair.of(ALOAD, 1); + } + } + + private AbstractInsnNode getIntInsn(int num) { + switch (num) { + case 0: + return new InsnNode(ICONST_0); + case 1: + return new InsnNode(ICONST_1); + case 2: + return new InsnNode(ICONST_2); + case 3: + return new InsnNode(ICONST_3); + case 4: + return new InsnNode(ICONST_4); + case 5: + return new InsnNode(ICONST_5); + default: + return new IntInsnNode(BIPUSH, num); + } } private void injectTestableRef(ClassNode cn, MethodNode mn) { diff --git a/testable-agent/src/main/java/com/alibaba/testable/agent/util/ClassUtil.java b/testable-agent/src/main/java/com/alibaba/testable/agent/util/ClassUtil.java index f3bf44e..dd66d1d 100644 --- a/testable-agent/src/main/java/com/alibaba/testable/agent/util/ClassUtil.java +++ b/testable-agent/src/main/java/com/alibaba/testable/agent/util/ClassUtil.java @@ -5,41 +5,55 @@ import com.alibaba.testable.agent.tool.ComparableWeakRef; import org.objectweb.asm.ClassReader; import org.objectweb.asm.tree.AnnotationNode; import org.objectweb.asm.tree.ClassNode; +import org.objectweb.asm.tree.MethodInsnNode; import org.objectweb.asm.tree.MethodNode; import java.util.*; +import static org.objectweb.asm.Opcodes.INVOKESTATIC; + /** * @author flin */ public class ClassUtil { - private static final char TYPE_BYTE = 'B'; - private static final char TYPE_CHAR = 'C'; - private static final char TYPE_DOUBLE = 'D'; - private static final char TYPE_FLOAT = 'F'; - private static final char TYPE_INT = 'I'; - private static final char TYPE_LONG = 'J'; - private static final char TYPE_CLASS = 'L'; - private static final char TYPE_SHORT = 'S'; - private static final char TYPE_BOOL = 'Z'; - private static final char PARAM_END = ')'; - private static final char CLASS_END = ';'; - private static final char TYPE_ARRAY = '['; + public static final byte TYPE_BYTE = 'B'; + public static final byte TYPE_CHAR = 'C'; + public static final byte TYPE_DOUBLE = 'D'; + public static final byte TYPE_FLOAT = 'F'; + public static final byte TYPE_INT = 'I'; + public static final byte TYPE_LONG = 'J'; + public static final byte TYPE_CLASS = 'L'; + public static final byte TYPE_SHORT = 'S'; + public static final byte TYPE_BOOL = 'Z'; + private static final byte PARAM_END = ')'; + private static final byte CLASS_END = ';'; + private static final byte TYPE_ARRAY = '['; - private static final Map TYPE_MAPPING = new HashMap(); + public static final String CLASS_OBJECT = "java/lang/Object"; + private static final String CLASS_BYTE = "java/lang/Byte"; + private static final String CLASS_CHARACTER = "java/lang/Character"; + private static final String CLASS_DOUBLE = "java/lang/Double"; + private static final String CLASS_FLOAT = "java/lang/Float"; + private static final String CLASS_INTEGER = "java/lang/Integer"; + private static final String CLASS_LONG = "java/lang/Long"; + private static final String CLASS_SHORT = "java/lang/Short"; + private static final String CLASS_BOOLEAN = "java/lang/Boolean"; + private static final String METHOD_VALUE_OF = "valueOf"; + + private static final Map TYPE_MAPPING = new HashMap(); private static final Map, Boolean> loadedClass = new WeakHashMap, Boolean>(); static { - TYPE_MAPPING.put(TYPE_BYTE, "java/lang/Byte"); - TYPE_MAPPING.put(TYPE_CHAR, "java/lang/Character"); - TYPE_MAPPING.put(TYPE_DOUBLE, "java/lang/Double"); - TYPE_MAPPING.put(TYPE_FLOAT, "java/lang/Float"); - TYPE_MAPPING.put(TYPE_INT, "java/lang/Integer"); - TYPE_MAPPING.put(TYPE_LONG, "java/lang/Long"); - TYPE_MAPPING.put(TYPE_SHORT, "java/lang/Short"); - TYPE_MAPPING.put(TYPE_BOOL, "java/lang/Boolean"); + TYPE_MAPPING.put(TYPE_BYTE, CLASS_BYTE); + TYPE_MAPPING.put(TYPE_CHAR, CLASS_CHARACTER); + TYPE_MAPPING.put(TYPE_DOUBLE, CLASS_DOUBLE); + TYPE_MAPPING.put(TYPE_FLOAT, CLASS_FLOAT); + TYPE_MAPPING.put(TYPE_INT, CLASS_INTEGER); + TYPE_MAPPING.put(TYPE_LONG, CLASS_LONG); + TYPE_MAPPING.put(TYPE_SHORT, CLASS_SHORT); + TYPE_MAPPING.put(TYPE_BOOL, CLASS_BOOLEAN); } /** @@ -139,13 +153,27 @@ public class ClassUtil { return desc.substring(returnTypeEdge + 1); } else if (typeChar == TYPE_CLASS) { return desc.substring(returnTypeEdge + 2, desc.length() - 1); - } else if (TYPE_MAPPING.containsKey(typeChar)) { - return TYPE_MAPPING.get(typeChar); + } else if (TYPE_MAPPING.containsKey((byte)typeChar)) { + return TYPE_MAPPING.get((byte)typeChar); } else { return ""; } } + /** + * Get method node to convert primary type to object type + * @param type primary type to convert + */ + public static MethodInsnNode getPrimaryTypeConvertMethod(Byte type) { + String objectType = TYPE_MAPPING.get(type); + return (objectType == null) ? null : + new MethodInsnNode(INVOKESTATIC, objectType, METHOD_VALUE_OF, toDescriptor(type, objectType), false); + } + + private static String toDescriptor(Byte type, String objectType) { + return "(" + (char)type.byteValue() + ")L" + objectType + ";"; + } + /** * convert slash separated name to dot separated name */ @@ -164,7 +192,7 @@ public class ClassUtil { * convert dot separated name to byte code class name */ public static String toByteCodeClassName(String className) { - return TYPE_CLASS + toSlashSeparatedName(className) + CLASS_END; + return (char)TYPE_CLASS + toSlashSeparatedName(className) + (char)CLASS_END; } /** @@ -185,5 +213,4 @@ public class ClassUtil { return b == TYPE_BYTE || b == TYPE_CHAR || b == TYPE_DOUBLE || b == TYPE_FLOAT || b == TYPE_INT || b == TYPE_LONG || b == TYPE_SHORT || b == TYPE_BOOL; } - } diff --git a/testable-core/src/main/java/com/alibaba/testable/core/util/InvokeRecordUtil.java b/testable-core/src/main/java/com/alibaba/testable/core/util/InvokeRecordUtil.java index 38e4e81..3f017d2 100644 --- a/testable-core/src/main/java/com/alibaba/testable/core/util/InvokeRecordUtil.java +++ b/testable-core/src/main/java/com/alibaba/testable/core/util/InvokeRecordUtil.java @@ -17,7 +17,7 @@ public class InvokeRecordUtil { public static final int INDEX_OF_TEST_CLASS = 2; /** - * Record mock method invoke event + * Record mock method invoke count */ public static void countMockInvoke() { StackTraceElement mockMethodTraceElement = Thread.currentThread().getStackTrace()[INDEX_OF_TEST_CLASS]; @@ -29,13 +29,20 @@ public class InvokeRecordUtil { INVOKE_RECORDS.put(key, count + 1); } + /** + * Record mock method invoke event + */ + public static void recordMockInvoke(Object[] args) { + countMockInvoke(); + } + + /** + * Get mock method invoke count + */ public static int getInvokeCount(String mockMethodName, String testCaseName) { String key = testCaseName + JOINER + mockMethodName; Integer count = INVOKE_RECORDS.get(key); - if (count == null) { - count = 0; - } - return count; + return (count == null) ? 0 : count; } }