diff --git a/agent/src/main/java/com/alibaba/testable/handler/TestableClassHandler.java b/agent/src/main/java/com/alibaba/testable/handler/TestableClassHandler.java index bf99b9c..2b226c7 100644 --- a/agent/src/main/java/com/alibaba/testable/handler/TestableClassHandler.java +++ b/agent/src/main/java/com/alibaba/testable/handler/TestableClassHandler.java @@ -1,7 +1,7 @@ package com.alibaba.testable.handler; -import com.alibaba.testable.model.TravelStatus; import com.alibaba.testable.util.ClassUtil; +import com.alibaba.testable.util.StringUtil; import org.objectweb.asm.ClassReader; import org.objectweb.asm.ClassWriter; import org.objectweb.asm.Opcodes; @@ -21,6 +21,12 @@ public class TestableClassHandler implements Opcodes { private static final String CONSTRUCTOR = ""; private static final String TESTABLE_NE = "n/e"; + private static final String TESTABLE_W = "w"; + private static final String TESTABLE_F = "f"; + private static final String CONSTRUCTOR_DESC_PREFIX = "(Ljava/lang/Class;"; + private static final String METHOD_DESC_PREFIX = "(Ljava/lang/Object;Ljava/lang/String;"; + private static final String OBJECT_DESC = "Ljava/lang/Object;"; + private static final String METHOD_DESC_POSTFIX = ")Ljava/lang/Object;"; public byte[] getBytes(String className) throws IOException { ClassReader cr = new ClassReader(className); @@ -46,38 +52,50 @@ public class TestableClassHandler implements Opcodes { private void transformMethod(ClassNode cn, MethodNode mn, List methodNames) { AbstractInsnNode[] instructions = mn.instructions.toArray(); - TravelStatus status = TravelStatus.INIT; - String target = ""; - int rangeStart = 0; int i = 0; do { - if (instructions[i].getOpcode() == Opcodes.NEW) { - TypeInsnNode node = (TypeInsnNode)instructions[i]; - if (!SYS_CLASSES.contains(node.desc)) { - target = node.desc; - status = TravelStatus.NEW_REP; - rangeStart = i; - } - } else if (instructions[i].getOpcode() == Opcodes.INVOKESPECIAL) { + if (instructions[i].getOpcode() == Opcodes.INVOKESPECIAL) { MethodInsnNode node = (MethodInsnNode)instructions[i]; - if (methodNames.contains(node.name) && cn.name.equals(node.owner)) { - status = TravelStatus.MEM_REP; - } else if (TravelStatus.NEW_REP == status && CONSTRUCTOR.equals(node.name) && target.equals(node.owner)) { - instructions = replaceNewOps(mn, instructions, rangeStart, i); + int rangeEnd = i; + if (cn.name.equals(node.owner) && methodNames.contains(node.name)) { + int rangeStart = getMemberMethodStart(instructions, rangeEnd); + instructions = replaceMemberCallOps(mn, instructions, rangeStart, rangeEnd); + i = rangeStart; + } else if (CONSTRUCTOR.equals(node.name) && !SYS_CLASSES.contains(node.owner)) { + int rangeStart = getConstructorStart(instructions, node.owner, rangeEnd); + instructions = replaceNewOps(mn, instructions, rangeStart, rangeEnd); i = rangeStart; - status = TravelStatus.INIT; } } i++; } while (i < instructions.length); } + private int getConstructorStart(AbstractInsnNode[] instructions, String target, int rangeEnd) { + for (int i = rangeEnd - 1; i > 0; i--) { + if (instructions[i].getOpcode() == Opcodes.NEW && ((TypeInsnNode)instructions[i]).desc.equals(target)) { + return i; + } + } + return 0; + } + + private int getMemberMethodStart(AbstractInsnNode[] instructions, int rangeEnd) { + for (int i = rangeEnd - 1; i > 0; i--) { + if (instructions[i].getOpcode() == Opcodes.ALOAD && ((VarInsnNode)instructions[i]).var == 0) { + return i; + } + } + return 0; + } + private AbstractInsnNode[] replaceNewOps(MethodNode mn, AbstractInsnNode[] instructions, int start, int end) { String classType = ((TypeInsnNode)instructions[start]).desc; - String paramTypes = ((MethodInsnNode)instructions[end]).desc; + String constructorDesc = ((MethodInsnNode)instructions[end]).desc; mn.instructions.insertBefore(instructions[start], new LdcInsnNode(Type.getType("L" + classType + ";"))); InsnList il = new InsnList(); - il.add(new MethodInsnNode(INVOKESTATIC, TESTABLE_NE, "w", ClassUtil.generateTargetDesc(paramTypes), false)); + il.add(new MethodInsnNode(INVOKESTATIC, TESTABLE_NE, TESTABLE_W, + getConstructorSubstitutionDesc(constructorDesc), false)); il.add(new TypeInsnNode(CHECKCAST, classType)); mn.instructions.insertBefore(instructions[end], il); mn.instructions.remove(instructions[start]); @@ -87,4 +105,29 @@ public class TestableClassHandler implements Opcodes { return mn.instructions.toArray(); } + private String getConstructorSubstitutionDesc(String constructorDesc) { + int paramCount = ClassUtil.getParameterCount(constructorDesc); + return CONSTRUCTOR_DESC_PREFIX + StringUtil.repeat(OBJECT_DESC, paramCount) + METHOD_DESC_POSTFIX; + } + + private AbstractInsnNode[] replaceMemberCallOps(MethodNode mn, AbstractInsnNode[] instructions, int start, int end) { + String methodDesc = ((MethodInsnNode)instructions[end]).desc; + String returnType = ClassUtil.getReturnType(methodDesc); + String methodName = ((MethodInsnNode)instructions[end]).name; + mn.instructions.insert(instructions[start], new LdcInsnNode(methodName)); + InsnList il = new InsnList(); + il.add(new MethodInsnNode(INVOKESTATIC, TESTABLE_NE, TESTABLE_F, + getMethodSubstitutionDesc(methodDesc), false)); + il.add(new TypeInsnNode(CHECKCAST, returnType)); + mn.instructions.insertBefore(instructions[end], il); + mn.instructions.remove(instructions[end]); + mn.maxStack += 1; + return mn.instructions.toArray(); + } + + private String getMethodSubstitutionDesc(String methodDesc) { + int paramCount = ClassUtil.getParameterCount(methodDesc); + return METHOD_DESC_PREFIX + StringUtil.repeat(OBJECT_DESC, paramCount) + METHOD_DESC_POSTFIX; + } + } diff --git a/agent/src/main/java/com/alibaba/testable/model/TravelStatus.java b/agent/src/main/java/com/alibaba/testable/model/TravelStatus.java deleted file mode 100644 index 2f55ddb..0000000 --- a/agent/src/main/java/com/alibaba/testable/model/TravelStatus.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.alibaba.testable.model; - -/** - * @author flin - */ - -public enum TravelStatus { - - /** - * Initialized - */ - INIT, - /** - * Processing member method replacement - */ - MEM_REP, - /** - * Processing member operation replacement - */ - NEW_REP - -} diff --git a/agent/src/main/java/com/alibaba/testable/util/ClassUtil.java b/agent/src/main/java/com/alibaba/testable/util/ClassUtil.java index 0ce178c..f326946 100644 --- a/agent/src/main/java/com/alibaba/testable/util/ClassUtil.java +++ b/agent/src/main/java/com/alibaba/testable/util/ClassUtil.java @@ -30,7 +30,7 @@ public class ClassUtil { } } - public static String generateTargetDesc(String paramTypes) { + public static int getParameterCount(String paramTypes) { int paramCount = 0; boolean travelingClass = false; for (byte b : paramTypes.getBytes()) { @@ -49,15 +49,11 @@ public class ClassUtil { } } } - return "(Ljava/lang/Class;" + repeat("Ljava/lang/Object;", paramCount) + ")Ljava/lang/Object;"; + return paramCount; } - private static String repeat(String text, int times) { - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < times; i++) { - sb.append(text); - } - return sb.toString(); + public static String getReturnType(String desc) { + return null; } } diff --git a/agent/src/main/java/com/alibaba/testable/util/StringUtil.java b/agent/src/main/java/com/alibaba/testable/util/StringUtil.java new file mode 100644 index 0000000..20f2f8d --- /dev/null +++ b/agent/src/main/java/com/alibaba/testable/util/StringUtil.java @@ -0,0 +1,16 @@ +package com.alibaba.testable.util; + +/** + * @author flin + */ +public class StringUtil { + + public static String repeat(String text, int times) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < times; i++) { + sb.append(text); + } + return sb.toString(); + } + +} diff --git a/agent/src/test/java/com/alibaba/testable/util/ClassUtilTest.java b/agent/src/test/java/com/alibaba/testable/util/ClassUtilTest.java index 739794b..dd4936b 100644 --- a/agent/src/test/java/com/alibaba/testable/util/ClassUtilTest.java +++ b/agent/src/test/java/com/alibaba/testable/util/ClassUtilTest.java @@ -9,13 +9,13 @@ class ClassUtilTest { @Test void should_able_to_generate_target_desc() { assertEquals("(Ljava/lang/Class;Ljava/lang/Object;)Ljava/lang/Object;", - ClassUtil.generateTargetDesc("(Ljava/lang/String;)V")); + ClassUtil.getParameterCount("(Ljava/lang/String;)V")); assertEquals("(Ljava/lang/Class;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", - ClassUtil.generateTargetDesc("(Ljava/lang/String;IDLjava/lang/String;ZLjava/net/URL;)V")); + ClassUtil.getParameterCount("(Ljava/lang/String;IDLjava/lang/String;ZLjava/net/URL;)V")); assertEquals("(Ljava/lang/Class;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", - ClassUtil.generateTargetDesc("(ZLjava/lang/String;IJFDCSBZ)V")); + ClassUtil.getParameterCount("(ZLjava/lang/String;IJFDCSBZ)V")); assertEquals("(Ljava/lang/Class;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", - ClassUtil.generateTargetDesc("(Ljava/lang/String;[I[Ljava/lang/String;)V")); + ClassUtil.getParameterCount("(Ljava/lang/String;[I[Ljava/lang/String;)V")); } }