exact match substitution method

This commit is contained in:
金戟 2020-07-27 23:13:55 +08:00
parent 4918a17a08
commit ee275a319a
8 changed files with 138 additions and 43 deletions

View File

@ -1,8 +1,5 @@
package com.alibaba.testable.agent.constant; package com.alibaba.testable.agent.constant;
import java.util.ArrayList;
import java.util.List;
/** /**
* @author flin * @author flin
*/ */
@ -13,9 +10,6 @@ public class ConstPool {
public static final String TEST_POSTFIX = "Test"; public static final String TEST_POSTFIX = "Test";
public static final List<String> SYS_CLASSES = new ArrayList<String>(); public static final String ENABLE_TESTABLE = "com.alibaba.testable.core.annotation.EnableTestable";
static { public static final String TESTABLE_INJECT = "com.alibaba.testable.core.annotation.TestableInject";
SYS_CLASSES.add("java/lang/StringBuilder");
}
} }

View File

@ -1,12 +1,15 @@
package com.alibaba.testable.agent.handler; package com.alibaba.testable.agent.handler;
import com.alibaba.testable.agent.constant.ConstPool; import com.alibaba.testable.agent.constant.ConstPool;
import com.alibaba.testable.agent.model.MethodInfo;
import com.alibaba.testable.agent.util.ClassUtil; import com.alibaba.testable.agent.util.ClassUtil;
import com.alibaba.testable.agent.util.CollectionUtil;
import com.alibaba.testable.agent.util.StringUtil; import com.alibaba.testable.agent.util.StringUtil;
import org.objectweb.asm.Opcodes; import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type; import org.objectweb.asm.Type;
import org.objectweb.asm.tree.*; import org.objectweb.asm.tree.*;
import java.util.ArrayList;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@ -24,33 +27,42 @@ public class SourceClassHandler extends ClassHandler {
private static final String METHOD_DESC_PREFIX = "(Ljava/lang/Object;Ljava/lang/String;"; private static final String METHOD_DESC_PREFIX = "(Ljava/lang/Object;Ljava/lang/String;";
private static final String OBJECT_DESC = "Ljava/lang/Object;"; private static final String OBJECT_DESC = "Ljava/lang/Object;";
private static final String METHOD_DESC_POSTFIX = ")Ljava/lang/Object;"; private static final String METHOD_DESC_POSTFIX = ")Ljava/lang/Object;";
private List<MethodInfo> injectMethods;
public SourceClassHandler(List<MethodInfo> injectMethods) {
this.injectMethods = injectMethods;
}
@Override @Override
protected void transform(ClassNode cn) { protected void transform(ClassNode cn) {
Set<String> methodNames = new HashSet<String>(); List<MethodInfo> methods = new ArrayList<MethodInfo>();
for (MethodNode m : cn.methods) { for (MethodNode m : cn.methods) {
if (!CONSTRUCTOR.equals(m.name)) { if (!CONSTRUCTOR.equals(m.name)) {
methodNames.add(m.name); methods.add(new MethodInfo(m.name, m.desc));
} }
} }
Set<MethodInfo> memberInjectMethods = CollectionUtil.getCrossSet(methods, injectMethods);
Set<MethodInfo> newOperatorInjectMethods = CollectionUtil.getMinusSet(injectMethods, memberInjectMethods);
for (MethodNode m : cn.methods) { for (MethodNode m : cn.methods) {
transformMethod(cn, m, methodNames); transformMethod(cn, m, memberInjectMethods, MethodInfo.descSet(newOperatorInjectMethods));
} }
} }
private void transformMethod(ClassNode cn, MethodNode mn, Set<String> methodNames) { private void transformMethod(ClassNode cn, MethodNode mn, Set<MethodInfo> memberInjectMethods,
Set<String> newOperatorInjectDesc) {
AbstractInsnNode[] instructions = mn.instructions.toArray(); AbstractInsnNode[] instructions = mn.instructions.toArray();
int i = 0; int i = 0;
do { do {
if (instructions[i].getOpcode() == Opcodes.INVOKESPECIAL) { if (instructions[i].getOpcode() == Opcodes.INVOKESPECIAL) {
MethodInsnNode node = (MethodInsnNode)instructions[i]; MethodInsnNode node = (MethodInsnNode)instructions[i];
if (cn.name.equals(node.owner) && methodNames.contains(node.name)) { if (cn.name.equals(node.owner) && memberInjectMethods.contains(new MethodInfo(node.name, node.desc))) {
int rangeStart = getMemberMethodStart(instructions, i); int rangeStart = getMemberMethodStart(instructions, i);
if (rangeStart >= 0) { if (rangeStart >= 0) {
instructions = replaceMemberCallOps(mn, instructions, rangeStart, i); instructions = replaceMemberCallOps(mn, instructions, rangeStart, i);
i = rangeStart; i = rangeStart;
} }
} else if (CONSTRUCTOR.equals(node.name) && !ConstPool.SYS_CLASSES.contains(node.owner)) { } else if (CONSTRUCTOR.equals(node.name) &&
newOperatorInjectDesc.contains(getConstructorInjectDesc(node))) {
int rangeStart = getConstructorStart(instructions, node.owner, i); int rangeStart = getConstructorStart(instructions, node.owner, i);
if (rangeStart >= 0) { if (rangeStart >= 0) {
instructions = replaceNewOps(mn, instructions, rangeStart, i); instructions = replaceNewOps(mn, instructions, rangeStart, i);
@ -62,6 +74,11 @@ public class SourceClassHandler extends ClassHandler {
} while (i < instructions.length); } while (i < instructions.length);
} }
private String getConstructorInjectDesc(MethodInsnNode constructorNode) {
return constructorNode.desc.substring(0, constructorNode.desc.length() - 1) +
ClassUtil.toByteCodeClassName(constructorNode.owner);
}
private int getConstructorStart(AbstractInsnNode[] instructions, String target, int rangeEnd) { private int getConstructorStart(AbstractInsnNode[] instructions, String target, int rangeEnd) {
for (int i = rangeEnd - 1; i >= 0; i--) { for (int i = rangeEnd - 1; i >= 0; i--) {
if (instructions[i].getOpcode() == Opcodes.NEW && ((TypeInsnNode)instructions[i]).desc.equals(target)) { if (instructions[i].getOpcode() == Opcodes.NEW && ((TypeInsnNode)instructions[i]).desc.equals(target)) {

View File

@ -1,26 +1,28 @@
package com.alibaba.testable.agent.model; package com.alibaba.testable.agent.model;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
/** /**
* @author flin * @author flin
*/ */
public class MethodInfo { public class MethodInfo {
private int access;
private String name; private String name;
private String desc; private String desc;
private String signature;
private String[] exceptions;
public MethodInfo(int access, String name, String desc, String signature, String[] exceptions) { public MethodInfo(String name, String desc) {
this.access = access;
this.name = name; this.name = name;
this.desc = desc; this.desc = desc;
this.signature = signature;
this.exceptions = exceptions;
} }
public int getAccess() { public static Set<String> descSet(Collection<MethodInfo> methodInfos) {
return access; Set<String> set = new HashSet<String>();
for (MethodInfo m : methodInfos) {
set.add(m.desc);
}
return set;
} }
public String getName() { public String getName() {
@ -31,11 +33,21 @@ public class MethodInfo {
return desc; return desc;
} }
public String getSignature() { @Override
return signature; public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
MethodInfo that = (MethodInfo)o;
return name.equals(that.name) && desc.equals(that.desc);
} }
public String[] getExceptions() { @Override
return exceptions; public int hashCode() {
return 31 * name.hashCode() + desc.hashCode();
} }
} }

View File

@ -3,6 +3,7 @@ package com.alibaba.testable.agent.transformer;
import com.alibaba.testable.agent.constant.ConstPool; import com.alibaba.testable.agent.constant.ConstPool;
import com.alibaba.testable.agent.handler.SourceClassHandler; import com.alibaba.testable.agent.handler.SourceClassHandler;
import com.alibaba.testable.agent.handler.TestClassHandler; import com.alibaba.testable.agent.handler.TestClassHandler;
import com.alibaba.testable.agent.model.MethodInfo;
import com.alibaba.testable.agent.util.ClassUtil; import com.alibaba.testable.agent.util.ClassUtil;
import java.io.IOException; import java.io.IOException;
@ -18,9 +19,6 @@ import java.util.Set;
*/ */
public class TestableClassTransformer implements ClassFileTransformer { public class TestableClassTransformer implements ClassFileTransformer {
private static final String ENABLE_TESTABLE = "com.alibaba.testable.core.annotation.EnableTestable";
private static final String ENABLE_TESTABLE_INJECT = "com.alibaba.testable.core.annotation.EnableTestableInject";
private static final Set<String> loadedClassNames = new HashSet<String>(); private static final Set<String> loadedClassNames = new HashSet<String>();
public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined, public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined,
@ -33,10 +31,11 @@ public class TestableClassTransformer implements ClassFileTransformer {
List<String> annotations = ClassUtil.getAnnotations(className); List<String> annotations = ClassUtil.getAnnotations(className);
List<String> testAnnotations = ClassUtil.getAnnotations(className + ConstPool.TEST_POSTFIX); List<String> testAnnotations = ClassUtil.getAnnotations(className + ConstPool.TEST_POSTFIX);
try { try {
if (annotations.contains(ENABLE_TESTABLE_INJECT) || testAnnotations.contains(ENABLE_TESTABLE)) { if (testAnnotations.contains(ConstPool.ENABLE_TESTABLE)) {
loadedClassNames.add(className); loadedClassNames.add(className);
return new SourceClassHandler().getBytes(className); List<MethodInfo> injectMethods = ClassUtil.getTestableInjectMethods(className + ConstPool.TEST_POSTFIX);
} else if (annotations.contains(ENABLE_TESTABLE)) { return new SourceClassHandler(injectMethods).getBytes(className);
} else if (annotations.contains(ConstPool.ENABLE_TESTABLE)) {
loadedClassNames.add(className); loadedClassNames.add(className);
return new TestClassHandler().getBytes(className); return new TestClassHandler().getBytes(className);
} }

View File

@ -1,9 +1,11 @@
package com.alibaba.testable.agent.util; package com.alibaba.testable.agent.util;
import com.alibaba.testable.agent.constant.ConstPool; import com.alibaba.testable.agent.constant.ConstPool;
import com.alibaba.testable.agent.model.MethodInfo;
import org.objectweb.asm.ClassReader; import org.objectweb.asm.ClassReader;
import org.objectweb.asm.tree.AnnotationNode; import org.objectweb.asm.tree.AnnotationNode;
import org.objectweb.asm.tree.ClassNode; import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.MethodNode;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
@ -37,8 +39,7 @@ public class ClassUtil {
ClassNode cn = new ClassNode(); ClassNode cn = new ClassNode();
new ClassReader(className).accept(cn, 0); new ClassReader(className).accept(cn, 0);
for (AnnotationNode an : cn.visibleAnnotations) { for (AnnotationNode an : cn.visibleAnnotations) {
String annotationName = an.desc.replace(ConstPool.SLASH, ConstPool.DOT).substring(1, an.desc.length() - 1); annotations.add(toDotSeparateFullClassName(an.desc));
annotations.add(annotationName);
} }
return annotations; return annotations;
} catch (Exception e) { } catch (Exception e) {
@ -46,6 +47,32 @@ public class ClassUtil {
} }
} }
public static List<MethodInfo> getTestableInjectMethods(String className) {
try {
List<MethodInfo> methodInfos = new ArrayList<MethodInfo>();
ClassNode cn = new ClassNode();
new ClassReader(className).accept(cn, 0);
for (MethodNode mn : cn.methods) {
checkMethodAnnotation(methodInfos, mn);
}
return methodInfos;
} catch (Exception e) {
return new ArrayList<MethodInfo>();
}
}
private static void checkMethodAnnotation(List<MethodInfo> methodInfos, MethodNode mn) {
if (mn.visibleAnnotations == null) {
return;
}
for (AnnotationNode an : mn.visibleAnnotations) {
if (toDotSeparateFullClassName(an.desc).equals(ConstPool.TESTABLE_INJECT)) {
methodInfos.add(new MethodInfo(mn.name, mn.desc));
break;
}
}
}
public static List<Byte> getParameterTypes(String desc) { public static List<Byte> getParameterTypes(String desc) {
List<Byte> parameterTypes = new ArrayList<Byte>(); List<Byte> parameterTypes = new ArrayList<Byte>();
boolean travelingClass = false; boolean travelingClass = false;
@ -101,4 +128,9 @@ public class ClassUtil {
public static String toByteCodeClassName(String className) { public static String toByteCodeClassName(String className) {
return TYPE_CLASS + className.replace(ConstPool.DOT, ConstPool.SLASH) + CLASS_END; return TYPE_CLASS + className.replace(ConstPool.DOT, ConstPool.SLASH) + CLASS_END;
} }
public static String toDotSeparateFullClassName(String className) {
return className.replace(ConstPool.SLASH, ConstPool.DOT).substring(1, className.length() - 1);
}
} }

View File

@ -1,18 +1,20 @@
package com.alibaba.testable.agent.util; package com.alibaba.testable.agent.util;
import java.util.ArrayList; import java.util.*;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
/** /**
* @author flin * @author flin
*/ */
public class CollectionUtil { public class CollectionUtil {
public static boolean containsAny(Collection hostContainer, Collection itemsToFind) { /**
for (Object o : hostContainer) { * Check two collection has any equaled item
for (Object i : itemsToFind) { * @param collectionLeft the first collection
* @param collectionRight the second collection
*/
public static boolean containsAny(Collection collectionLeft, Collection collectionRight) {
for (Object o : collectionLeft) {
for (Object i : collectionRight) {
if (o.equals(i)) { if (o.equals(i)) {
return true; return true;
} }
@ -21,10 +23,44 @@ public class CollectionUtil {
return false; return false;
} }
/**
* Generate a list of item
* @param items elements to add
*/
public static <T> List<T> listOf(T... items) { public static <T> List<T> listOf(T... items) {
List<T> list = new ArrayList<T>(items.length); List<T> list = new ArrayList<T>(items.length);
Collections.addAll(list, items); Collections.addAll(list, items);
return list; return list;
} }
/**
* Get cross set of two collections
* @param collectionLeft the first collection
* @param collectionRight the second collection
*/
public static <T> Set<T> getCrossSet(Collection<T> collectionLeft, Collection<T> collectionRight) {
Set<T> crossSet = new HashSet<T>();
for (T i : collectionLeft) {
if (collectionRight.contains(i)) {
crossSet.add(i);
}
}
return crossSet;
}
/**
* Get minus set of two collections
* @param collectionRaw original collection
* @param collectionMinus items to remove
*/
public static <T> Set<T> getMinusSet(Collection<T> collectionRaw, Collection<T> collectionMinus) {
Set<T> crossSet = new HashSet<T>();
for (T i : collectionRaw) {
if (!collectionMinus.contains(i)) {
crossSet.add(i);
}
}
return crossSet;
}
} }

View File

@ -12,4 +12,9 @@ import java.lang.annotation.*;
@Documented @Documented
public @interface EnableTestableInject { public @interface EnableTestableInject {
/**
* Test class names
*/
String[] value();
} }

View File

@ -7,7 +7,7 @@ import java.lang.annotation.*;
* *
* @author flin * @author flin
*/ */
@Retention(RetentionPolicy.SOURCE) @Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD) @Target(ElementType.METHOD)
@Documented @Documented
public @interface TestableInject { public @interface TestableInject {