diff --git a/testable-agent/src/main/java/com/alibaba/testable/agent/handler/MockClassHandler.java b/testable-agent/src/main/java/com/alibaba/testable/agent/handler/MockClassHandler.java index e07c1e1..d8d4c8a 100644 --- a/testable-agent/src/main/java/com/alibaba/testable/agent/handler/MockClassHandler.java +++ b/testable-agent/src/main/java/com/alibaba/testable/agent/handler/MockClassHandler.java @@ -11,7 +11,6 @@ import org.objectweb.asm.tree.*; import java.util.List; -import static com.alibaba.testable.agent.constant.ConstPool.CONSTRUCTOR; import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName; /** @@ -23,7 +22,7 @@ public class MockClassHandler extends BaseClassWithContextHandler { private static final String CLASS_MOCK_ASSOCIATION_UTIL = "com/alibaba/testable/core/util/MockAssociationUtil"; private static final String METHOD_INVOKE_ORIGIN = "invokeOrigin"; private static final String SIGNATURE_INVOKE_ORIGIN = - "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/Object;[Ljava/lang/Object;)Ljava/lang/Object;"; + "(Ljava/lang/Class;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/Object;"; private static final String METHOD_RECORD_MOCK_INVOKE = "recordMockInvoke"; private static final String SIGNATURE_RECORDER_METHOD_INVOKE = "([Ljava/lang/Object;Z)V"; private static final String METHOD_IS_ASSOCIATED = "isAssociated"; @@ -61,7 +60,7 @@ public class MockClassHandler extends BaseClassWithContextHandler { il.add(new JumpInsnNode(IFNONNULL, label)); il.add(new TypeInsnNode(NEW, mockClassName)); il.add(new InsnNode(DUP)); - il.add(new MethodInsnNode(INVOKESPECIAL, mockClassName, CONSTRUCTOR, VOID_ARGS + VOID_RES, false)); + il.add(new MethodInsnNode(INVOKESPECIAL, mockClassName, ConstPool.CONSTRUCTOR, VOID_ARGS + VOID_RES, false)); il.add(new FieldInsnNode(PUTSTATIC, mockClassName, TESTABLE_REF, ClassUtil.toByteCodeClassName(mockClassName))); il.add(label); il.add(new FrameNode(F_SAME, 0, null, 0, null)); @@ -123,9 +122,52 @@ public class MockClassHandler extends BaseClassWithContextHandler { private InsnList invokeOriginalMethod(MethodNode mn) { InsnList il = new InsnList(); + mn.maxStack += 3; + ImmutablePair target = getTargetClassAndMethodName(mn); + il.add(new LdcInsnNode(target.left)); + il.add(new LdcInsnNode(target.right)); + il.add(duplicateParameters(mn)); + il.add(new MethodInsnNode(INVOKESTATIC, CLASS_MOCK_ASSOCIATION_UTIL, METHOD_INVOKE_ORIGIN, + SIGNATURE_INVOKE_ORIGIN, false)); + String returnType = ClassUtil.getReturnType(mn.desc); + if (VOID_RES.equals(returnType)) { + il.add(new InsnNode(POP)); + il.add(new InsnNode(RETURN)); + } else if (returnType.startsWith("[") || returnType.startsWith("L")) { + il.add(new TypeInsnNode(CHECKCAST, returnType)); + il.add(new InsnNode(ARETURN)); + } else { + String wrapperClass = ClassUtil.toWrapperClass(returnType.getBytes()[0]); + il.add(new TypeInsnNode(CHECKCAST, wrapperClass)); + ImmutablePair convertMethod = ClassUtil.getWrapperTypeConvertMethod(returnType.getBytes()[0]); + il.add(new MethodInsnNode(INVOKEVIRTUAL, wrapperClass, convertMethod.left, convertMethod.right, false)); + il.add(new InsnNode(ClassUtil.getReturnOpsCode(returnType))); + } return il; } + private ImmutablePair getTargetClassAndMethodName(MethodNode mn) { + Type className; + String methodName = mn.name; + for (AnnotationNode an : mn.visibleAnnotations) { + if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc)) { + String name = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_METHOD, + null, String.class); + if (name != null) { + methodName = name; + } + } else if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_CONSTRUCTOR).equals(an.desc)) { + methodName = ConstPool.CONSTRUCTOR; + } + } + if (methodName.equals(ConstPool.CONSTRUCTOR)) { + className = Type.getType(ClassUtil.getReturnType(mn.desc)); + } else { + className = Type.getType(ClassUtil.getFirstParameter(mn.desc)); + } + return ImmutablePair.of(className, methodName); + } + private boolean isGlobalScope(MethodNode mn) { for (AnnotationNode an : mn.visibleAnnotations) { if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc) || @@ -140,15 +182,6 @@ public class MockClassHandler extends BaseClassWithContextHandler { return false; } - private LabelNode getFirstLabel(MethodNode mn) { - for (AbstractInsnNode n : mn.instructions) { - if (n instanceof LabelNode) { - return (LabelNode)n; - } - } - return null; - } - private boolean isMockMethod(MethodNode mn) { if (mn.visibleAnnotations == null) { return false; @@ -163,13 +196,26 @@ public class MockClassHandler extends BaseClassWithContextHandler { } private void injectInvokeRecorder(MethodNode mn) { + InsnList il = new InsnList(); + mn.maxStack += 2; + il.add(duplicateParameters(mn)); + if (isMockForConstructor(mn)) { + il.add(new InsnNode(ICONST_1)); + } else { + il.add(new InsnNode(ICONST_0)); + } + il.add(new MethodInsnNode(INVOKESTATIC, CLASS_INVOKE_RECORD_UTIL, METHOD_RECORD_MOCK_INVOKE, + SIGNATURE_RECORDER_METHOD_INVOKE, false)); + mn.instructions.insertBefore(mn.instructions.getFirst(), il); + } + + private InsnList duplicateParameters(MethodNode mn) { InsnList il = new InsnList(); List types = ClassUtil.getParameterTypes(mn.desc); int size = types.size(); - int parameterOffset = 1; - mn.maxStack += 2; il.add(getIntInsn(size)); il.add(new TypeInsnNode(ANEWARRAY, ClassUtil.CLASS_OBJECT)); + int parameterOffset = 1; for (int i = 0; i < size; i++) { mn.maxStack += 3; il.add(new InsnNode(DUP)); @@ -183,14 +229,7 @@ public class MockClassHandler extends BaseClassWithContextHandler { } il.add(new InsnNode(AASTORE)); } - if (isMockForConstructor(mn)) { - il.add(new InsnNode(ICONST_1)); - } else { - il.add(new InsnNode(ICONST_0)); - } - il.add(new MethodInsnNode(INVOKESTATIC, CLASS_INVOKE_RECORD_UTIL, METHOD_RECORD_MOCK_INVOKE, - SIGNATURE_RECORDER_METHOD_INVOKE, false)); - mn.instructions.insertBefore(mn.instructions.getFirst(), il); + return il; } private boolean isMockForConstructor(MethodNode mn) { diff --git a/testable-agent/src/main/java/com/alibaba/testable/agent/handler/SourceClassHandler.java b/testable-agent/src/main/java/com/alibaba/testable/agent/handler/SourceClassHandler.java index 5d62828..ceab231 100644 --- a/testable-agent/src/main/java/com/alibaba/testable/agent/handler/SourceClassHandler.java +++ b/testable-agent/src/main/java/com/alibaba/testable/agent/handler/SourceClassHandler.java @@ -195,7 +195,7 @@ public class SourceClassHandler extends BaseClassHandler { } private int stackEffectOfInvocation(String desc) { - return ClassUtil.getParameterTypes(desc).size() - (ClassUtil.getReturnType(desc).isEmpty() ? 0 : 1); + return ClassUtil.getParameterTypes(desc).size() - (ClassUtil.getReturnType(desc).equals(VOID_RES) ? 0 : 1); } private ModifiedInsnNodes replaceNewOps(ClassNode cn, MethodNode mn, String newOperatorInjectMethodName, diff --git a/testable-agent/src/main/java/com/alibaba/testable/agent/util/ClassUtil.java b/testable-agent/src/main/java/com/alibaba/testable/agent/util/ClassUtil.java index 098abd1..c9cce27 100644 --- a/testable-agent/src/main/java/com/alibaba/testable/agent/util/ClassUtil.java +++ b/testable-agent/src/main/java/com/alibaba/testable/agent/util/ClassUtil.java @@ -1,6 +1,7 @@ package com.alibaba.testable.agent.util; import com.alibaba.testable.agent.constant.ConstPool; +import com.alibaba.testable.agent.tool.ImmutablePair; import org.objectweb.asm.ClassReader; import org.objectweb.asm.tree.ClassNode; import org.objectweb.asm.tree.MethodInsnNode; @@ -13,7 +14,7 @@ import java.util.Map; import static com.alibaba.testable.core.constant.ConstPool.MOCK_POSTFIX; import static com.alibaba.testable.core.constant.ConstPool.TEST_POSTFIX; -import static org.objectweb.asm.Opcodes.INVOKESTATIC; +import static org.objectweb.asm.Opcodes.*; /** * @author flin @@ -45,8 +46,19 @@ public class ClassUtil { private static final String CLASS_BOOLEAN = "java/lang/Boolean"; private static final String EMPTY = ""; private static final String METHOD_VALUE_OF = "valueOf"; + private static final String METHOD_BYTE_VALUE = "byteValue"; + private static final String METHOD_CHAR_VALUE = "charValue"; + private static final String METHOD_DOUBLE_VALUE = "doubleValue"; + private static final String METHOD_FLOAT_VALUE = "floatValue"; + private static final String METHOD_INT_VALUE = "intValue"; + private static final String METHOD_LONG_VALUE = "longValue"; + private static final String METHOD_SHORT_VALUE = "shortValue"; + private static final String METHOD_BOOLEAN_VALUE = "booleanValue"; private static final Map TYPE_MAPPING = new HashMap(); + private static final Map> WRAPPER_METHOD_MAPPING = + new HashMap>(); + private static final Map RETURN_OP_CODE_MAPPING = new HashMap(); static { TYPE_MAPPING.put(TYPE_BYTE, CLASS_BYTE); @@ -60,6 +72,28 @@ public class ClassUtil { TYPE_MAPPING.put(TYPE_VOID, EMPTY); } + static { + WRAPPER_METHOD_MAPPING.put(TYPE_BYTE, ImmutablePair.of(METHOD_BYTE_VALUE, "()" + (char)TYPE_BYTE)); + WRAPPER_METHOD_MAPPING.put(TYPE_CHAR, ImmutablePair.of(METHOD_CHAR_VALUE, "()" + (char)TYPE_CHAR)); + WRAPPER_METHOD_MAPPING.put(TYPE_DOUBLE, ImmutablePair.of(METHOD_DOUBLE_VALUE, "()" + (char)TYPE_DOUBLE)); + WRAPPER_METHOD_MAPPING.put(TYPE_FLOAT, ImmutablePair.of(METHOD_FLOAT_VALUE, "()" + (char)TYPE_FLOAT)); + WRAPPER_METHOD_MAPPING.put(TYPE_INT, ImmutablePair.of(METHOD_INT_VALUE, "()" + (char)TYPE_INT)); + WRAPPER_METHOD_MAPPING.put(TYPE_LONG, ImmutablePair.of(METHOD_LONG_VALUE, "()" + (char)TYPE_LONG)); + WRAPPER_METHOD_MAPPING.put(TYPE_SHORT, ImmutablePair.of(METHOD_SHORT_VALUE, "()" + (char)TYPE_SHORT)); + WRAPPER_METHOD_MAPPING.put(TYPE_BOOL, ImmutablePair.of(METHOD_BOOLEAN_VALUE, "()" + (char)TYPE_BOOL)); + } + + static { + RETURN_OP_CODE_MAPPING.put(new String(new byte[] {TYPE_BYTE}), IRETURN); + RETURN_OP_CODE_MAPPING.put(new String(new byte[] {TYPE_CHAR}), IRETURN); + RETURN_OP_CODE_MAPPING.put(new String(new byte[] {TYPE_DOUBLE}), DRETURN); + RETURN_OP_CODE_MAPPING.put(new String(new byte[] {TYPE_FLOAT}), FRETURN); + RETURN_OP_CODE_MAPPING.put(new String(new byte[] {TYPE_INT}), IRETURN); + RETURN_OP_CODE_MAPPING.put(new String(new byte[] {TYPE_LONG}), LRETURN); + RETURN_OP_CODE_MAPPING.put(new String(new byte[] {TYPE_SHORT}), IRETURN); + RETURN_OP_CODE_MAPPING.put(new String(new byte[] {TYPE_BOOL}), IRETURN); + } + /** * fit kotlin companion class name to original name * @param name a class name (which could be a companion class) @@ -166,20 +200,49 @@ public class ClassUtil { */ public static String getReturnType(String desc) { int returnTypeEdge = desc.lastIndexOf(PARAM_END); - char typeChar = desc.charAt(returnTypeEdge + 1); - if (typeChar == TYPE_ARRAY) { - return desc.substring(returnTypeEdge + 1); - } else if (typeChar == TYPE_CLASS) { - return desc.substring(returnTypeEdge + 2, desc.length() - 1); - } else if (TYPE_MAPPING.containsKey((byte)typeChar)) { - return TYPE_MAPPING.get((byte)typeChar); - } else { - return EMPTY; - } + return desc.substring(returnTypeEdge + 1); } /** - * Get method node to convert primary type to object type + * parse method desc, fetch first parameter type + * @param desc method description + * @return types of first parameter + */ + public static String getFirstParameter(String desc) { + int typeEdge = desc.indexOf(CLASS_END); + return desc.substring(1, typeEdge + 1); + } + + /** + * get wrapper class of specified private type + * @param primaryType byte code of private type + * @return byte code of wrapper class + */ + public static String toWrapperClass(Byte primaryType) { + return TYPE_MAPPING.get(primaryType); + } + + /** + * get method name and descriptor to convert wrapper type to primary type + * @param primaryType byte code of private type + * @return pair of + */ + public static ImmutablePair getWrapperTypeConvertMethod(byte primaryType) { + return WRAPPER_METHOD_MAPPING.get(primaryType); + } + + /** + * get byte code for return specified private type + * @param type class type + * @return byte code of return operation + */ + public static int getReturnOpsCode(String type) { + Integer code = RETURN_OP_CODE_MAPPING.get(type); + return (code == null) ? ARETURN : code; + } + + /** + * get method node to convert primary type to wrapper type * @param type primary type to convert * @return converter method node */ @@ -216,6 +279,15 @@ public class ClassUtil { return (char)TYPE_CLASS + toSlashSeparatedName(className) + (char)CLASS_END; } + /** + * convert byte code class name to slash separated human readable name + * @param className original name + * @return converted name + */ + public static String toSlashSeparateFullClassName(String className) { + return className.substring(1, className.length() - 1); + } + /** * convert byte code class name to dot separated human readable name * @param className original name diff --git a/testable-agent/src/test/java/com/alibaba/testable/agent/util/ClassUtilTest.java b/testable-agent/src/test/java/com/alibaba/testable/agent/util/ClassUtilTest.java index 84883f0..2e58ce0 100644 --- a/testable-agent/src/test/java/com/alibaba/testable/agent/util/ClassUtilTest.java +++ b/testable-agent/src/test/java/com/alibaba/testable/agent/util/ClassUtilTest.java @@ -23,13 +23,18 @@ class ClassUtilTest { @Test void should_able_to_get_return_type() { - assertEquals("", ClassUtil.getReturnType("(Ljava/lang/String;)V")); - assertEquals("java/lang/Integer", ClassUtil.getReturnType("(Ljava/lang/String;)I")); + assertEquals("V", ClassUtil.getReturnType("(Ljava/lang/String;)V")); + assertEquals("I", ClassUtil.getReturnType("(Ljava/lang/String;)I")); assertEquals("[I", ClassUtil.getReturnType("(Ljava/lang/String;)[I")); - assertEquals("java/lang/String", ClassUtil.getReturnType("(Ljava/lang/String;)Ljava/lang/String;")); + assertEquals("Ljava/lang/String;", ClassUtil.getReturnType("(Ljava/lang/String;)Ljava/lang/String;")); assertEquals("[Ljava/lang/String;", ClassUtil.getReturnType("(Ljava/lang/String;)[Ljava/lang/String;")); } + @Test + void should_able_to_get_first_parameter() { + assertEquals("Ljava/lang/String;", ClassUtil.getFirstParameter("(Ljava/lang/String;Ljava/lang/Object;I)V")); + } + @Test void should_able_to_convert_class_name() { assertEquals("Ljava/lang/String;", ClassUtil.toByteCodeClassName("java.lang.String")); diff --git a/testable-core/src/main/java/com/alibaba/testable/core/util/MockAssociationUtil.java b/testable-core/src/main/java/com/alibaba/testable/core/util/MockAssociationUtil.java index 33cce63..1bcc9fd 100644 --- a/testable-core/src/main/java/com/alibaba/testable/core/util/MockAssociationUtil.java +++ b/testable-core/src/main/java/com/alibaba/testable/core/util/MockAssociationUtil.java @@ -28,14 +28,19 @@ public class MockAssociationUtil { */ public static boolean isAssociated() { MockContext mockContext = MockContextUtil.context.get(); - String testClassName = (mockContext == null) ? "" : mockContext.testClassName; + if (mockContext == null) { + // skip the association check + LogUtil.warn("Mock association check is invoked without test context"); + return true; + } + String testClassName = mockContext.testClassName; String mockClassName = Thread.currentThread().getStackTrace()[INDEX_OF_MOCK_CLASS].getClassName(); return isAssociatedByInnerMockClass(testClassName, mockClassName) || isAssociatedByOuterMockClass(testClassName, mockClassName) || isAssociatedByMockWithAnnotation(testClassName, mockClassName); } - public static Object invokeOrigin(String originClass, String originMethod, Object originObj, Object... args) { + public static Object invokeOrigin(Class originClass, String originMethod, Object... args) { return null; }