diff --git a/agent/src/main/java/com/alibaba/testable/handler/TestableClassHandler.java b/agent/src/main/java/com/alibaba/testable/handler/TestableClassHandler.java index 2b226c7..15ee99f 100644 --- a/agent/src/main/java/com/alibaba/testable/handler/TestableClassHandler.java +++ b/agent/src/main/java/com/alibaba/testable/handler/TestableClassHandler.java @@ -1,7 +1,6 @@ package com.alibaba.testable.handler; import com.alibaba.testable.util.ClassUtil; -import com.alibaba.testable.util.StringUtil; import org.objectweb.asm.ClassReader; import org.objectweb.asm.ClassWriter; import org.objectweb.asm.Opcodes; @@ -9,8 +8,8 @@ import org.objectweb.asm.Type; import org.objectweb.asm.tree.*; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; +import java.util.HashSet; +import java.util.Set; import static com.alibaba.testable.constant.Const.SYS_CLASSES; @@ -39,7 +38,7 @@ public class TestableClassHandler implements Opcodes { } private void transform(ClassNode cn) { - List methodNames = new ArrayList(); + Set methodNames = new HashSet(); for (MethodNode m : cn.methods) { if (!CONSTRUCTOR.equals(m.name)) { methodNames.add(m.name); @@ -50,21 +49,24 @@ public class TestableClassHandler implements Opcodes { } } - private void transformMethod(ClassNode cn, MethodNode mn, List methodNames) { + private void transformMethod(ClassNode cn, MethodNode mn, Set methodNames) { AbstractInsnNode[] instructions = mn.instructions.toArray(); int i = 0; do { if (instructions[i].getOpcode() == Opcodes.INVOKESPECIAL) { MethodInsnNode node = (MethodInsnNode)instructions[i]; - int rangeEnd = i; if (cn.name.equals(node.owner) && methodNames.contains(node.name)) { - int rangeStart = getMemberMethodStart(instructions, rangeEnd); - instructions = replaceMemberCallOps(mn, instructions, rangeStart, rangeEnd); - i = rangeStart; + int rangeStart = getMemberMethodStart(instructions, i); + if (rangeStart >= 0) { + instructions = replaceMemberCallOps(mn, instructions, rangeStart, i); + i = rangeStart; + } } else if (CONSTRUCTOR.equals(node.name) && !SYS_CLASSES.contains(node.owner)) { - int rangeStart = getConstructorStart(instructions, node.owner, rangeEnd); - instructions = replaceNewOps(mn, instructions, rangeStart, rangeEnd); - i = rangeStart; + int rangeStart = getConstructorStart(instructions, node.owner, i); + if (rangeStart >= 0) { + instructions = replaceNewOps(mn, instructions, rangeStart, i); + i = rangeStart; + } } } i++; @@ -72,21 +74,21 @@ public class TestableClassHandler implements Opcodes { } private int getConstructorStart(AbstractInsnNode[] instructions, String target, int rangeEnd) { - for (int i = rangeEnd - 1; i > 0; i--) { + for (int i = rangeEnd - 1; i >= 0; i--) { if (instructions[i].getOpcode() == Opcodes.NEW && ((TypeInsnNode)instructions[i]).desc.equals(target)) { return i; } } - return 0; + return -1; } private int getMemberMethodStart(AbstractInsnNode[] instructions, int rangeEnd) { - for (int i = rangeEnd - 1; i > 0; i--) { + for (int i = rangeEnd - 1; i >= 0; i--) { if (instructions[i].getOpcode() == Opcodes.ALOAD && ((VarInsnNode)instructions[i]).var == 0) { return i; } } - return 0; + return -1; } private AbstractInsnNode[] replaceNewOps(MethodNode mn, AbstractInsnNode[] instructions, int start, int end) { @@ -107,7 +109,7 @@ public class TestableClassHandler implements Opcodes { private String getConstructorSubstitutionDesc(String constructorDesc) { int paramCount = ClassUtil.getParameterCount(constructorDesc); - return CONSTRUCTOR_DESC_PREFIX + StringUtil.repeat(OBJECT_DESC, paramCount) + METHOD_DESC_POSTFIX; + return CONSTRUCTOR_DESC_PREFIX + ClassUtil.repeat(OBJECT_DESC, paramCount) + METHOD_DESC_POSTFIX; } private AbstractInsnNode[] replaceMemberCallOps(MethodNode mn, AbstractInsnNode[] instructions, int start, int end) { @@ -127,7 +129,7 @@ public class TestableClassHandler implements Opcodes { private String getMethodSubstitutionDesc(String methodDesc) { int paramCount = ClassUtil.getParameterCount(methodDesc); - return METHOD_DESC_PREFIX + StringUtil.repeat(OBJECT_DESC, paramCount) + METHOD_DESC_POSTFIX; + return METHOD_DESC_PREFIX + ClassUtil.repeat(OBJECT_DESC, paramCount) + METHOD_DESC_POSTFIX; } } diff --git a/agent/src/main/java/com/alibaba/testable/util/ClassUtil.java b/agent/src/main/java/com/alibaba/testable/util/ClassUtil.java index f326946..299e80f 100644 --- a/agent/src/main/java/com/alibaba/testable/util/ClassUtil.java +++ b/agent/src/main/java/com/alibaba/testable/util/ClassUtil.java @@ -30,10 +30,10 @@ public class ClassUtil { } } - public static int getParameterCount(String paramTypes) { + public static int getParameterCount(String desc) { int paramCount = 0; boolean travelingClass = false; - for (byte b : paramTypes.getBytes()) { + for (byte b : desc.getBytes()) { if (travelingClass) { if (b == ';') { travelingClass = false; @@ -53,7 +53,26 @@ public class ClassUtil { } public static String getReturnType(String desc) { - return null; + int returnTypeEdge = desc.lastIndexOf(')'); + boolean isArrayType = false; + if (desc.charAt(returnTypeEdge + 1) == '[') { + isArrayType = true; + returnTypeEdge++; + } + switch (desc.charAt(returnTypeEdge + 1)) { + case 'L': + return desc.substring(returnTypeEdge + 2, desc.length() - 1); + default: + return ""; + } + } + + public static String repeat(String text, int times) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < times; i++) { + sb.append(text); + } + return sb.toString(); } } diff --git a/agent/src/main/java/com/alibaba/testable/util/StringUtil.java b/agent/src/main/java/com/alibaba/testable/util/StringUtil.java deleted file mode 100644 index 20f2f8d..0000000 --- a/agent/src/main/java/com/alibaba/testable/util/StringUtil.java +++ /dev/null @@ -1,16 +0,0 @@ -package com.alibaba.testable.util; - -/** - * @author flin - */ -public class StringUtil { - - public static String repeat(String text, int times) { - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < times; i++) { - sb.append(text); - } - return sb.toString(); - } - -} diff --git a/agent/src/test/java/com/alibaba/testable/util/ClassUtilTest.java b/agent/src/test/java/com/alibaba/testable/util/ClassUtilTest.java index dd4936b..946e7f5 100644 --- a/agent/src/test/java/com/alibaba/testable/util/ClassUtilTest.java +++ b/agent/src/test/java/com/alibaba/testable/util/ClassUtilTest.java @@ -7,15 +7,17 @@ import static org.junit.jupiter.api.Assertions.*; class ClassUtilTest { @Test - void should_able_to_generate_target_desc() { - assertEquals("(Ljava/lang/Class;Ljava/lang/Object;)Ljava/lang/Object;", - ClassUtil.getParameterCount("(Ljava/lang/String;)V")); - assertEquals("(Ljava/lang/Class;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", - ClassUtil.getParameterCount("(Ljava/lang/String;IDLjava/lang/String;ZLjava/net/URL;)V")); - assertEquals("(Ljava/lang/Class;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", - ClassUtil.getParameterCount("(ZLjava/lang/String;IJFDCSBZ)V")); - assertEquals("(Ljava/lang/Class;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", - ClassUtil.getParameterCount("(Ljava/lang/String;[I[Ljava/lang/String;)V")); + void should_able_to_get_parameter_count() { + assertEquals(1, ClassUtil.getParameterCount("(Ljava/lang/String;)V")); + assertEquals(6, ClassUtil.getParameterCount("(Ljava/lang/String;IDLjava/lang/String;ZLjava/net/URL;)V")); + assertEquals(10, ClassUtil.getParameterCount("(ZLjava/lang/String;IJFDCSBZ)V")); + assertEquals(3, ClassUtil.getParameterCount("(Ljava/lang/String;[I[Ljava/lang/String;)V")); + } + + @Test + void should_able_to_get_return_type() { + assertEquals("", ClassUtil.getReturnType("(Ljava/lang/String;)V")); + assertEquals("java/lang/String", ClassUtil.getReturnType("(Ljava/lang/String;)Ljava/lang/String;")); } }