diff --git a/testable-agent/src/main/java/com/alibaba/testable/agent/handler/SourceClassHandler.java b/testable-agent/src/main/java/com/alibaba/testable/agent/handler/SourceClassHandler.java index 55e85ea..63bbe25 100644 --- a/testable-agent/src/main/java/com/alibaba/testable/agent/handler/SourceClassHandler.java +++ b/testable-agent/src/main/java/com/alibaba/testable/agent/handler/SourceClassHandler.java @@ -3,13 +3,13 @@ package com.alibaba.testable.agent.handler; import com.alibaba.testable.agent.constant.ConstPool; import com.alibaba.testable.agent.model.MethodInfo; import com.alibaba.testable.agent.model.ModifiedInsnNodes; +import com.alibaba.testable.agent.tool.ImmutablePair; import com.alibaba.testable.agent.util.BytecodeUtil; import com.alibaba.testable.agent.util.ClassUtil; import com.alibaba.testable.core.util.LogUtil; import org.objectweb.asm.Opcodes; import org.objectweb.asm.tree.*; -import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -60,18 +60,17 @@ public class SourceClassHandler extends BaseClassHandler { Set newOperatorInjectMethods) { LogUtil.diagnose(" Handling method %s", mn.name); AbstractInsnNode[] instructions = mn.instructions.toArray(); - List memberInjectMethodList = new ArrayList(memberInjectMethods); int i = 0; int maxStackDiff = 0; do { if (invokeOps.contains(instructions[i].getOpcode())) { MethodInsnNode node = (MethodInsnNode)instructions[i]; - String memberInjectMethodName = getMemberInjectMethodName(memberInjectMethodList, node); - if (memberInjectMethodName != null) { + ImmutablePair mockMethod = getMemberInjectMethodName(memberInjectMethods, node); + if (mockMethod != null) { // it's a member or static method and an inject method for it exist int rangeStart = getMemberMethodStart(instructions, i); if (rangeStart >= 0) { - ModifiedInsnNodes modifiedInsnNodes = replaceMemberCallOps(cn, mn, memberInjectMethodName, + ModifiedInsnNodes modifiedInsnNodes = replaceMemberCallOps(cn, mn, mockMethod, instructions, node.owner, node.getOpcode(), rangeStart, i); instructions = modifiedInsnNodes.nodes; maxStackDiff = Math.max(maxStackDiff, modifiedInsnNodes.stackDiff); @@ -100,11 +99,18 @@ public class SourceClassHandler extends BaseClassHandler { mn.maxStack += maxStackDiff; } - private String getMemberInjectMethodName(List memberInjectMethodList, MethodInsnNode node) { - for (MethodInfo m : memberInjectMethodList) { + /** + * find the mock method fit for specified method node + * @param memberInjectMethods mock methods available + * @param node method node to match for + * @return pair of + */ + private ImmutablePair getMemberInjectMethodName(Set memberInjectMethods, + MethodInsnNode node) { + for (MethodInfo m : memberInjectMethods) { String nodeOwner = ClassUtil.fitCompanionClassName(node.owner); if (m.getClazz().equals(nodeOwner) && m.getName().equals(node.name) && m.getDesc().equals(node.desc)) { - return m.getMockName(); + return ImmutablePair.of(m.getMockName(), m.getMockDesc()); } } return null; @@ -213,10 +219,10 @@ public class SourceClassHandler extends BaseClassHandler { ClassUtil.toByteCodeClassName(classType); } - private ModifiedInsnNodes replaceMemberCallOps(ClassNode cn, MethodNode mn, String substitutionMethod, + private ModifiedInsnNodes replaceMemberCallOps(ClassNode cn, MethodNode mn, ImmutablePair mockMethod, AbstractInsnNode[] instructions, String ownerClass, int opcode, int start, int end) { - LogUtil.diagnose(" Line %d, mock method %s used", getLineNum(instructions, start), substitutionMethod); + LogUtil.diagnose(" Line %d, mock method %s used", getLineNum(instructions, start), mockMethod.left); MethodInsnNode method = (MethodInsnNode)instructions[end]; String testClassName = ClassUtil.getTestClassName(cn.name); if (Opcodes.INVOKESTATIC == opcode || isCompanionMethod(ownerClass, opcode)) { @@ -229,7 +235,7 @@ public class SourceClassHandler extends BaseClassHandler { } // method with @MockMethod will be modified as public static access, so INVOKESTATIC is used mn.instructions.insertBefore(instructions[end], new MethodInsnNode(INVOKESTATIC, testClassName, - substitutionMethod, addFirstParameter(method.desc, ClassUtil.fitCompanionClassName(ownerClass)), false)); + mockMethod.left, mockMethod.right, false)); mn.instructions.remove(instructions[end]); return new ModifiedInsnNodes(mn.instructions.toArray(), 1); } @@ -238,8 +244,4 @@ public class SourceClassHandler extends BaseClassHandler { return Opcodes.INVOKEVIRTUAL == opcode && ClassUtil.isCompanionClassName(ownerClass); } - private String addFirstParameter(String desc, String ownerClass) { - return "(" + ClassUtil.toByteCodeClassName(ownerClass) + desc.substring(1); - } - } diff --git a/testable-agent/src/main/java/com/alibaba/testable/agent/model/MethodInfo.java b/testable-agent/src/main/java/com/alibaba/testable/agent/model/MethodInfo.java index 7ef5ed6..1bacd3b 100644 --- a/testable-agent/src/main/java/com/alibaba/testable/agent/model/MethodInfo.java +++ b/testable-agent/src/main/java/com/alibaba/testable/agent/model/MethodInfo.java @@ -13,20 +13,25 @@ public class MethodInfo { * name of the source method */ private final String name; + /** + * parameter and return value of the source method + */ + private final String desc; /** * name of the mock method */ private final String mockName; /** - * parameter and return value of the source method + * parameter and return value of the mock method */ - private final String desc; + private final String mockDesc; - public MethodInfo(String clazz, String name, String mockName, String desc) { + public MethodInfo(String clazz, String name, String desc, String mockName, String mockDesc) { this.clazz = clazz; this.name = name; - this.mockName = mockName; this.desc = desc; + this.mockName = mockName; + this.mockDesc = mockDesc; } public String getClazz() { @@ -37,12 +42,16 @@ public class MethodInfo { return name; } + public String getDesc() { + return desc; + } + public String getMockName() { return mockName; } - public String getDesc() { - return desc; + public String getMockDesc() { + return mockDesc; } @Override @@ -54,16 +63,18 @@ public class MethodInfo { if (!clazz.equals(that.clazz)) { return false; } if (!name.equals(that.name)) { return false; } + if (!desc.equals(that.desc)) { return false; } if (!mockName.equals(that.mockName)) { return false; } - return desc.equals(that.desc); + return mockDesc.equals(that.mockDesc); } @Override public int hashCode() { int result = clazz.hashCode(); result = 31 * result + name.hashCode(); - result = 31 * result + mockName.hashCode(); result = 31 * result + desc.hashCode(); + result = 31 * result + mockName.hashCode(); + result = 31 * result + mockDesc.hashCode(); return result; } } diff --git a/testable-agent/src/main/java/com/alibaba/testable/agent/transformer/TestableClassTransformer.java b/testable-agent/src/main/java/com/alibaba/testable/agent/transformer/TestableClassTransformer.java index e81fbb0..e74a2e7 100644 --- a/testable-agent/src/main/java/com/alibaba/testable/agent/transformer/TestableClassTransformer.java +++ b/testable-agent/src/main/java/com/alibaba/testable/agent/transformer/TestableClassTransformer.java @@ -3,16 +3,17 @@ package com.alibaba.testable.agent.transformer; import com.alibaba.testable.agent.constant.ConstPool; import com.alibaba.testable.agent.handler.SourceClassHandler; import com.alibaba.testable.agent.handler.TestClassHandler; -import com.alibaba.testable.agent.tool.ImmutablePair; import com.alibaba.testable.agent.model.MethodInfo; +import com.alibaba.testable.agent.tool.ImmutablePair; import com.alibaba.testable.agent.util.AnnotationUtil; import com.alibaba.testable.agent.util.ClassUtil; import com.alibaba.testable.agent.util.GlobalConfig; import com.alibaba.testable.agent.util.StringUtil; +import com.alibaba.testable.core.model.MockDiagnose; import com.alibaba.testable.core.model.NullType; import com.alibaba.testable.core.util.LogUtil; -import com.alibaba.testable.core.model.MockDiagnose; import org.objectweb.asm.ClassReader; +import org.objectweb.asm.Type; import org.objectweb.asm.tree.AnnotationNode; import org.objectweb.asm.tree.ClassNode; import org.objectweb.asm.tree.MethodNode; @@ -22,7 +23,8 @@ import java.io.FileOutputStream; import java.io.IOException; import java.lang.instrument.ClassFileTransformer; import java.security.ProtectionDomain; -import java.util.*; +import java.util.ArrayList; +import java.util.List; import static com.alibaba.testable.agent.constant.ConstPool.DOT; import static com.alibaba.testable.agent.constant.ConstPool.SLASH; @@ -138,44 +140,43 @@ public class TestableClassTransformer implements ClassFileTransformer { for (AnnotationNode an : mn.visibleAnnotations) { String fullClassName = toDotSeparateFullClassName(an.desc); if (fullClassName.equals(ConstPool.MOCK_CONSTRUCTOR)) { - addMockConstructor(cn, methodInfos, mn); + addMockConstructor(methodInfos, cn, mn); } else if (fullClassName.equals(ConstPool.MOCK_METHOD) || fullClassName.equals(ConstPool.TESTABLE_MOCK)) { - ImmutablePair methodDescPair = getMethodDescPair(mn, an); - if (methodDescPair == null) { - return; - } String targetMethod = AnnotationUtil.getAnnotationParameter( an, ConstPool.FIELD_TARGET_METHOD, mn.name, String.class); - if (targetMethod.equals(ConstPool.CONSTRUCTOR)) { - addMockConstructor(cn, methodInfos, mn); + if (ConstPool.CONSTRUCTOR.equals(targetMethod)) { + addMockConstructor(methodInfos, cn, mn); } else { - addMockMethod(methodInfos, mn, methodDescPair, targetMethod); + MethodInfo mi = getMethodInfo(mn, an, targetMethod); + if (mi != null) { + methodInfos.add(mi); + } } break; } } } - private ImmutablePair getMethodDescPair(MethodNode mn, AnnotationNode an) { - Class targetClass = AnnotationUtil.getAnnotationParameter( - an, ConstPool.FIELD_TARGET_CLASS, NullType.class, Class.class); - if (targetClass.equals(NullType.class)) { - return extractFirstParameter(mn.desc); + private MethodInfo getMethodInfo(MethodNode mn, AnnotationNode an, String targetMethod) { + Type targetType = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_CLASS, null, Type.class); + if (targetType == null || targetType.getClassName().equals(NullType.class.getName())) { + // "targetClass" unset, use first parameter as target class type + ImmutablePair methodDescPair = extractFirstParameter(mn.desc); + if (methodDescPair == null) { + return null; + } + return new MethodInfo(methodDescPair.left, targetMethod, methodDescPair.right, mn.name, mn.desc); } else { - return ImmutablePair.of(ClassUtil.toByteCodeClassName(targetClass.getName()), mn.desc); + // "targetClass" found, use it as target class type + String slashSeparatedName = ClassUtil.toSlashSeparatedName(targetType.getClassName()); + return new MethodInfo(slashSeparatedName, targetMethod, mn.desc, mn.name, mn.desc); } } - private void addMockMethod(List methodInfos, MethodNode mn, - ImmutablePair methodDescPair, String targetMethod) { - String targetClass = ClassUtil.toSlashSeparateFullClassName(methodDescPair.left); - methodInfos.add(new MethodInfo(targetClass, targetMethod, mn.name, methodDescPair.right)); - } - - private void addMockConstructor(ClassNode cn, List methodInfos, MethodNode mn) { + private void addMockConstructor(List methodInfos, ClassNode cn, MethodNode mn) { String sourceClassName = ClassUtil.getSourceClassName(cn.name); - methodInfos.add(new MethodInfo(sourceClassName, ConstPool.CONSTRUCTOR, mn.name, mn.desc)); + methodInfos.add(new MethodInfo(sourceClassName, ConstPool.CONSTRUCTOR, mn.desc, mn.name, mn.desc)); } /** @@ -232,7 +233,7 @@ public class TestableClassTransformer implements ClassFileTransformer { private ImmutablePair extractFirstParameter(String desc) { // assume first parameter is a class int pos = desc.indexOf(";"); - return pos < 0 ? null : ImmutablePair.of(desc.substring(1, pos + 1), "(" + desc.substring(pos + 1)); + return pos < 0 ? null : ImmutablePair.of(desc.substring(2, pos), "(" + desc.substring(pos + 1)); } } diff --git a/testable-agent/src/main/java/com/alibaba/testable/agent/util/AnnotationUtil.java b/testable-agent/src/main/java/com/alibaba/testable/agent/util/AnnotationUtil.java index 7c650cf..fbdc498 100644 --- a/testable-agent/src/main/java/com/alibaba/testable/agent/util/AnnotationUtil.java +++ b/testable-agent/src/main/java/com/alibaba/testable/agent/util/AnnotationUtil.java @@ -29,7 +29,11 @@ public class AnnotationUtil { Class enumClazz = (Class)clazz; return (T)Enum.valueOf(enumClazz, values[1]); } - return clazz.cast(an.values.get(i + 1)); + try { + return clazz.cast(an.values.get(i + 1)); + } catch (ClassCastException e) { + return defaultValue; + } } } }