diff --git a/agent/src/main/java/com/alibaba/testable/agent/handler/SourceClassHandler.java b/agent/src/main/java/com/alibaba/testable/agent/handler/SourceClassHandler.java index 6ea2f43..359ff78 100644 --- a/agent/src/main/java/com/alibaba/testable/agent/handler/SourceClassHandler.java +++ b/agent/src/main/java/com/alibaba/testable/agent/handler/SourceClassHandler.java @@ -67,7 +67,11 @@ public class SourceClassHandler extends BaseClassHandler { // it's a member method of current class and an inject method for it exist int rangeStart = getMemberMethodStart(instructions, i); if (rangeStart >= 0) { - instructions = replaceMemberCallOps(cn, mn, instructions, rangeStart, i); + if (cn.name.equals(node.owner)) { + instructions = replaceMemberCallOps(cn, mn, instructions, rangeStart, i); + } else { + instructions = replaceCommonCallOps(cn, mn, instructions, node.owner, rangeStart, i); + } i = rangeStart; } } else if (ConstPool.CONSTRUCTOR.equals(node.name)) { @@ -134,7 +138,7 @@ public class SourceClassHandler extends BaseClassHandler { AbstractInsnNode[] instructions, int start, int end) { String classType = ((TypeInsnNode)instructions[start]).desc; String constructorDesc = ((MethodInsnNode)instructions[end]).desc; - String testClassName = StringUtil.getTestClassName(cn.name); + String testClassName = ClassUtil.getTestClassName(cn.name); mn.instructions.insertBefore(instructions[start], new FieldInsnNode(GETSTATIC, testClassName, ConstPool.TESTABLE_INJECT_REF, ClassUtil.toByteCodeClassName(testClassName))); mn.instructions.insertBefore(instructions[end], new MethodInsnNode(INVOKEVIRTUAL, testClassName, @@ -153,7 +157,7 @@ public class SourceClassHandler extends BaseClassHandler { private AbstractInsnNode[] replaceMemberCallOps(ClassNode cn, MethodNode mn, AbstractInsnNode[] instructions, int start, int end) { MethodInsnNode method = (MethodInsnNode)instructions[end]; - String testClassName = StringUtil.getTestClassName(cn.name); + String testClassName = ClassUtil.getTestClassName(cn.name); mn.instructions.insertBefore(instructions[start], new FieldInsnNode(GETSTATIC, testClassName, ConstPool.TESTABLE_INJECT_REF, ClassUtil.toByteCodeClassName(testClassName))); mn.instructions.insertBefore(instructions[end], new MethodInsnNode(INVOKEVIRTUAL, testClassName, @@ -163,4 +167,21 @@ public class SourceClassHandler extends BaseClassHandler { return mn.instructions.toArray(); } + private AbstractInsnNode[] replaceCommonCallOps(ClassNode cn, MethodNode mn, AbstractInsnNode[] instructions, + String ownerClass, int start, int end) { + mn.maxStack++; + MethodInsnNode method = (MethodInsnNode)instructions[end]; + String testClassName = ClassUtil.getTestClassName(cn.name); + mn.instructions.insertBefore(instructions[start], new FieldInsnNode(GETSTATIC, testClassName, + ConstPool.TESTABLE_INJECT_REF, ClassUtil.toByteCodeClassName(testClassName))); + mn.instructions.insertBefore(instructions[end], new MethodInsnNode(INVOKEVIRTUAL, testClassName, + method.name, addFirstParameter(method.desc, ownerClass), false)); + mn.instructions.remove(instructions[end]); + return mn.instructions.toArray(); + } + + private String addFirstParameter(String desc, String ownerClass) { + return "(" + ClassUtil.toByteCodeClassName(ownerClass) + desc.substring(1); + } + } diff --git a/agent/src/main/java/com/alibaba/testable/agent/model/ImmutablePair.java b/agent/src/main/java/com/alibaba/testable/agent/model/ImmutablePair.java new file mode 100644 index 0000000..470171b --- /dev/null +++ b/agent/src/main/java/com/alibaba/testable/agent/model/ImmutablePair.java @@ -0,0 +1,22 @@ +package com.alibaba.testable.agent.model; + + +/** + * @author flin + */ +public class ImmutablePair { + + /** Left object */ + public final L left; + /** Right object */ + public final R right; + + public ImmutablePair(L left, R right) { + this.left = left; + this.right = right; + } + + public static ImmutablePair of(L l, R r) { + return new ImmutablePair(l, r); + } +} diff --git a/agent/src/main/java/com/alibaba/testable/agent/transformer/TestableClassTransformer.java b/agent/src/main/java/com/alibaba/testable/agent/transformer/TestableClassTransformer.java index 3cd6c82..c1b9df6 100644 --- a/agent/src/main/java/com/alibaba/testable/agent/transformer/TestableClassTransformer.java +++ b/agent/src/main/java/com/alibaba/testable/agent/transformer/TestableClassTransformer.java @@ -3,24 +3,34 @@ package com.alibaba.testable.agent.transformer; import com.alibaba.testable.agent.constant.ConstPool; import com.alibaba.testable.agent.handler.SourceClassHandler; import com.alibaba.testable.agent.handler.TestClassHandler; +import com.alibaba.testable.agent.model.ImmutablePair; import com.alibaba.testable.agent.model.MethodInfo; import com.alibaba.testable.agent.util.ClassUtil; -import com.alibaba.testable.agent.util.StringUtil; +import org.objectweb.asm.ClassReader; +import org.objectweb.asm.tree.AnnotationNode; +import org.objectweb.asm.tree.ClassNode; +import org.objectweb.asm.tree.MethodNode; import java.io.IOException; import java.lang.instrument.ClassFileTransformer; import java.net.URLClassLoader; import java.security.ProtectionDomain; +import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; +import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName; +import static com.alibaba.testable.agent.util.ClassUtil.toSlashSeparatedName; + /** * @author flin */ public class TestableClassTransformer implements ClassFileTransformer { private final Set loadedClassNames = new HashSet(); + private static final String TARGET_CLASS = "targetClass"; + private static final String TARGET_METHOD = "targetMethod"; public byte[] transform(ClassLoader loader, String className, Class classBeingRedefined, ProtectionDomain protectionDomain, byte[] classFileBuffer) { @@ -30,12 +40,12 @@ public class TestableClassTransformer implements ClassFileTransformer { } List annotations = ClassUtil.getAnnotations(className); - List testAnnotations = ClassUtil.getAnnotations(StringUtil.getTestClassName(className)); + List testAnnotations = ClassUtil.getAnnotations(ClassUtil.getTestClassName(className)); try { if (testAnnotations.contains(ConstPool.ENABLE_TESTABLE)) { // it's a source class with testable enabled loadedClassNames.add(className); - List injectMethods = ClassUtil.getTestableInjectMethods(StringUtil.getTestClassName(className)); + List injectMethods = getTestableInjectMethods(ClassUtil.getTestClassName(className)); return new SourceClassHandler(injectMethods).getBytes(className); } else if (annotations.contains(ConstPool.ENABLE_TESTABLE)) { // it's a test class with testable enabled @@ -52,4 +62,57 @@ public class TestableClassTransformer implements ClassFileTransformer { return !(loader instanceof URLClassLoader) || null == className || className.startsWith("jdk/"); } + private List getTestableInjectMethods(String className) { + try { + List methodInfos = new ArrayList(); + ClassNode cn = new ClassNode(); + new ClassReader(className).accept(cn, 0); + for (MethodNode mn : cn.methods) { + checkMethodAnnotation(cn, methodInfos, mn); + } + return methodInfos; + } catch (Exception e) { + return new ArrayList(); + } + } + + private void checkMethodAnnotation(ClassNode cn, List methodInfos, MethodNode mn) { + if (mn.visibleAnnotations == null) { + return; + } + for (AnnotationNode an : mn.visibleAnnotations) { + if (toDotSeparateFullClassName(an.desc).equals(ConstPool.TESTABLE_INJECT)) { + String sourceClassName = ClassUtil.getSourceClassName(cn.name); + String targetClass = getAnnotationParameter(an, TARGET_CLASS, sourceClassName); + String targetMethod = getAnnotationParameter(an, TARGET_METHOD, mn.name); + if (sourceClassName.equals(targetClass)) { + methodInfos.add(new MethodInfo(toSlashSeparatedName(targetClass), targetMethod, mn.desc)); + } else { + ImmutablePair methodDescPair = extractFirstParameter(mn.desc); + if (methodDescPair != null && methodDescPair.left.equals(ClassUtil.toByteCodeClassName(targetClass))) { + methodInfos.add(new MethodInfo( + toSlashSeparatedName(targetClass), targetMethod, methodDescPair.right)); + } + } + break; + } + } + } + + private ImmutablePair extractFirstParameter(String desc) { + // assume first parameter is a class + int pos = desc.indexOf(";"); + return pos < 0 ? null : ImmutablePair.of(desc.substring(1, pos + 1), "(" + desc.substring(pos + 1)); + } + + private String getAnnotationParameter(AnnotationNode an, String key, String defaultValue) { + if (an.values != null) { + int i = an.values.indexOf(key); + if (i % 2 == 0) { + return (String)an.values.get(i+1); + } + } + return defaultValue; + } + } diff --git a/agent/src/main/java/com/alibaba/testable/agent/util/ClassUtil.java b/agent/src/main/java/com/alibaba/testable/agent/util/ClassUtil.java index fe22b9e..cf3e3eb 100644 --- a/agent/src/main/java/com/alibaba/testable/agent/util/ClassUtil.java +++ b/agent/src/main/java/com/alibaba/testable/agent/util/ClassUtil.java @@ -1,11 +1,9 @@ package com.alibaba.testable.agent.util; import com.alibaba.testable.agent.constant.ConstPool; -import com.alibaba.testable.agent.model.MethodInfo; import org.objectweb.asm.ClassReader; import org.objectweb.asm.tree.AnnotationNode; import org.objectweb.asm.tree.ClassNode; -import org.objectweb.asm.tree.MethodNode; import java.util.ArrayList; import java.util.HashMap; @@ -31,8 +29,6 @@ public class ClassUtil { private static final char TYPE_ARRAY = '['; private static final Map TYPE_MAPPING = new HashMap(); - private static final String TARGET_CLASS = "targetClass"; - private static final String TARGET_METHOD = "targetMethod"; static { TYPE_MAPPING.put(TYPE_BYTE, "java/lang/Byte"); @@ -64,47 +60,24 @@ public class ClassUtil { } /** - * Get testable inject method from test class - * @param className test class name + * get test class name from source class name + * @param sourceClassName source class name */ - public static List getTestableInjectMethods(String className) { - try { - List methodInfos = new ArrayList(); - ClassNode cn = new ClassNode(); - new ClassReader(className).accept(cn, 0); - for (MethodNode mn : cn.methods) { - checkMethodAnnotation(cn, methodInfos, mn); - } - return methodInfos; - } catch (Exception e) { - return new ArrayList(); - } + public static String getTestClassName(String sourceClassName) { + return sourceClassName + ConstPool.TEST_POSTFIX; } - private static void checkMethodAnnotation(ClassNode cn, List methodInfos, MethodNode mn) { - if (mn.visibleAnnotations == null) { - return; - } - for (AnnotationNode an : mn.visibleAnnotations) { - if (toDotSeparateFullClassName(an.desc).equals(ConstPool.TESTABLE_INJECT)) { - String targetClass = getAnnotationParameter(an, TARGET_CLASS, StringUtil.getSourceClassName(cn.name)); - String targetMethod = getAnnotationParameter(an, TARGET_METHOD, mn.name); - methodInfos.add(new MethodInfo(toSlashSeparateName(targetClass), targetMethod, mn.desc)); - break; - } - } - } - - private static String getAnnotationParameter(AnnotationNode an, String key, String defaultValue) { - if (an.values != null) { - int i = an.values.indexOf(key); - if (i % 2 == 0) { - return (String)an.values.get(i+1); - } - } - return defaultValue; + /** + * get source class name from test class name + * @param testClassName test class name + */ + public static String getSourceClassName(String testClassName) { + return testClassName.substring(0, testClassName.length() - ConstPool.TEST_POSTFIX.length()); } + /** + * parse method desc, fetch parameter types + */ public static List getParameterTypes(String desc) { List parameterTypes = new ArrayList(); boolean travelingClass = false; @@ -127,6 +100,9 @@ public class ClassUtil { return parameterTypes; } + /** + * parse method desc, fetch return value types + */ public static String getReturnType(String desc) { int returnTypeEdge = desc.lastIndexOf(PARAM_END); char typeChar = desc.charAt(returnTypeEdge + 1); @@ -141,14 +117,23 @@ public class ClassUtil { } } - private static String toSlashSeparateName(String name) { + /** + * convert dot separated name to slash separated name + */ + public static String toSlashSeparatedName(String name) { return name.replace(ConstPool.DOT, ConstPool.SLASH); } + /** + * convert dot separated name to byte code class name + */ public static String toByteCodeClassName(String className) { - return TYPE_CLASS + toSlashSeparateName(className) + CLASS_END; + return TYPE_CLASS + toSlashSeparatedName(className) + CLASS_END; } + /** + * convert byte code class name to dot separated human readable name + */ public static String toDotSeparateFullClassName(String className) { return className.replace(ConstPool.SLASH, ConstPool.DOT).substring(1, className.length() - 1); } diff --git a/agent/src/main/java/com/alibaba/testable/agent/util/StringUtil.java b/agent/src/main/java/com/alibaba/testable/agent/util/StringUtil.java index 1e95315..40e2f5a 100644 --- a/agent/src/main/java/com/alibaba/testable/agent/util/StringUtil.java +++ b/agent/src/main/java/com/alibaba/testable/agent/util/StringUtil.java @@ -1,7 +1,5 @@ package com.alibaba.testable.agent.util; -import com.alibaba.testable.agent.constant.ConstPool; - /** * @author flin */ @@ -20,20 +18,4 @@ public class StringUtil { return sb.toString(); } - /** - * get test class name from source class name - * @param sourceClassName source class name - */ - public static String getTestClassName(String sourceClassName) { - return sourceClassName + ConstPool.TEST_POSTFIX; - } - - /** - * get source class name from test class name - * @param testClassName test class name - */ - public static String getSourceClassName(String testClassName) { - return testClassName.substring(0, testClassName.length() - ConstPool.TEST_POSTFIX.length()); - } - }