implement member method substitution for common object member method

This commit is contained in:
金戟 2020-10-13 22:40:59 +08:00
parent bc315b72e3
commit 64227f1ed0
5 changed files with 139 additions and 66 deletions

View File

@ -67,7 +67,11 @@ public class SourceClassHandler extends BaseClassHandler {
// it's a member method of current class and an inject method for it exist
int rangeStart = getMemberMethodStart(instructions, i);
if (rangeStart >= 0) {
if (cn.name.equals(node.owner)) {
instructions = replaceMemberCallOps(cn, mn, instructions, rangeStart, i);
} else {
instructions = replaceCommonCallOps(cn, mn, instructions, node.owner, rangeStart, i);
}
i = rangeStart;
}
} else if (ConstPool.CONSTRUCTOR.equals(node.name)) {
@ -134,7 +138,7 @@ public class SourceClassHandler extends BaseClassHandler {
AbstractInsnNode[] instructions, int start, int end) {
String classType = ((TypeInsnNode)instructions[start]).desc;
String constructorDesc = ((MethodInsnNode)instructions[end]).desc;
String testClassName = StringUtil.getTestClassName(cn.name);
String testClassName = ClassUtil.getTestClassName(cn.name);
mn.instructions.insertBefore(instructions[start], new FieldInsnNode(GETSTATIC, testClassName,
ConstPool.TESTABLE_INJECT_REF, ClassUtil.toByteCodeClassName(testClassName)));
mn.instructions.insertBefore(instructions[end], new MethodInsnNode(INVOKEVIRTUAL, testClassName,
@ -153,7 +157,7 @@ public class SourceClassHandler extends BaseClassHandler {
private AbstractInsnNode[] replaceMemberCallOps(ClassNode cn, MethodNode mn, AbstractInsnNode[] instructions,
int start, int end) {
MethodInsnNode method = (MethodInsnNode)instructions[end];
String testClassName = StringUtil.getTestClassName(cn.name);
String testClassName = ClassUtil.getTestClassName(cn.name);
mn.instructions.insertBefore(instructions[start], new FieldInsnNode(GETSTATIC, testClassName,
ConstPool.TESTABLE_INJECT_REF, ClassUtil.toByteCodeClassName(testClassName)));
mn.instructions.insertBefore(instructions[end], new MethodInsnNode(INVOKEVIRTUAL, testClassName,
@ -163,4 +167,21 @@ public class SourceClassHandler extends BaseClassHandler {
return mn.instructions.toArray();
}
private AbstractInsnNode[] replaceCommonCallOps(ClassNode cn, MethodNode mn, AbstractInsnNode[] instructions,
String ownerClass, int start, int end) {
mn.maxStack++;
MethodInsnNode method = (MethodInsnNode)instructions[end];
String testClassName = ClassUtil.getTestClassName(cn.name);
mn.instructions.insertBefore(instructions[start], new FieldInsnNode(GETSTATIC, testClassName,
ConstPool.TESTABLE_INJECT_REF, ClassUtil.toByteCodeClassName(testClassName)));
mn.instructions.insertBefore(instructions[end], new MethodInsnNode(INVOKEVIRTUAL, testClassName,
method.name, addFirstParameter(method.desc, ownerClass), false));
mn.instructions.remove(instructions[end]);
return mn.instructions.toArray();
}
private String addFirstParameter(String desc, String ownerClass) {
return "(" + ClassUtil.toByteCodeClassName(ownerClass) + desc.substring(1);
}
}

View File

@ -0,0 +1,22 @@
package com.alibaba.testable.agent.model;
/**
* @author flin
*/
public class ImmutablePair<L, R> {
/** Left object */
public final L left;
/** Right object */
public final R right;
public ImmutablePair(L left, R right) {
this.left = left;
this.right = right;
}
public static <L, R> ImmutablePair<L, R> of(L l, R r) {
return new ImmutablePair<L, R>(l, r);
}
}

View File

@ -3,24 +3,34 @@ package com.alibaba.testable.agent.transformer;
import com.alibaba.testable.agent.constant.ConstPool;
import com.alibaba.testable.agent.handler.SourceClassHandler;
import com.alibaba.testable.agent.handler.TestClassHandler;
import com.alibaba.testable.agent.model.ImmutablePair;
import com.alibaba.testable.agent.model.MethodInfo;
import com.alibaba.testable.agent.util.ClassUtil;
import com.alibaba.testable.agent.util.StringUtil;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.tree.AnnotationNode;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.MethodNode;
import java.io.IOException;
import java.lang.instrument.ClassFileTransformer;
import java.net.URLClassLoader;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName;
import static com.alibaba.testable.agent.util.ClassUtil.toSlashSeparatedName;
/**
* @author flin
*/
public class TestableClassTransformer implements ClassFileTransformer {
private final Set<String> loadedClassNames = new HashSet<String>();
private static final String TARGET_CLASS = "targetClass";
private static final String TARGET_METHOD = "targetMethod";
public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined,
ProtectionDomain protectionDomain, byte[] classFileBuffer) {
@ -30,12 +40,12 @@ public class TestableClassTransformer implements ClassFileTransformer {
}
List<String> annotations = ClassUtil.getAnnotations(className);
List<String> testAnnotations = ClassUtil.getAnnotations(StringUtil.getTestClassName(className));
List<String> testAnnotations = ClassUtil.getAnnotations(ClassUtil.getTestClassName(className));
try {
if (testAnnotations.contains(ConstPool.ENABLE_TESTABLE)) {
// it's a source class with testable enabled
loadedClassNames.add(className);
List<MethodInfo> injectMethods = ClassUtil.getTestableInjectMethods(StringUtil.getTestClassName(className));
List<MethodInfo> injectMethods = getTestableInjectMethods(ClassUtil.getTestClassName(className));
return new SourceClassHandler(injectMethods).getBytes(className);
} else if (annotations.contains(ConstPool.ENABLE_TESTABLE)) {
// it's a test class with testable enabled
@ -52,4 +62,57 @@ public class TestableClassTransformer implements ClassFileTransformer {
return !(loader instanceof URLClassLoader) || null == className || className.startsWith("jdk/");
}
private List<MethodInfo> getTestableInjectMethods(String className) {
try {
List<MethodInfo> methodInfos = new ArrayList<MethodInfo>();
ClassNode cn = new ClassNode();
new ClassReader(className).accept(cn, 0);
for (MethodNode mn : cn.methods) {
checkMethodAnnotation(cn, methodInfos, mn);
}
return methodInfos;
} catch (Exception e) {
return new ArrayList<MethodInfo>();
}
}
private void checkMethodAnnotation(ClassNode cn, List<MethodInfo> methodInfos, MethodNode mn) {
if (mn.visibleAnnotations == null) {
return;
}
for (AnnotationNode an : mn.visibleAnnotations) {
if (toDotSeparateFullClassName(an.desc).equals(ConstPool.TESTABLE_INJECT)) {
String sourceClassName = ClassUtil.getSourceClassName(cn.name);
String targetClass = getAnnotationParameter(an, TARGET_CLASS, sourceClassName);
String targetMethod = getAnnotationParameter(an, TARGET_METHOD, mn.name);
if (sourceClassName.equals(targetClass)) {
methodInfos.add(new MethodInfo(toSlashSeparatedName(targetClass), targetMethod, mn.desc));
} else {
ImmutablePair<String, String> methodDescPair = extractFirstParameter(mn.desc);
if (methodDescPair != null && methodDescPair.left.equals(ClassUtil.toByteCodeClassName(targetClass))) {
methodInfos.add(new MethodInfo(
toSlashSeparatedName(targetClass), targetMethod, methodDescPair.right));
}
}
break;
}
}
}
private ImmutablePair<String, String> extractFirstParameter(String desc) {
// assume first parameter is a class
int pos = desc.indexOf(";");
return pos < 0 ? null : ImmutablePair.of(desc.substring(1, pos + 1), "(" + desc.substring(pos + 1));
}
private String getAnnotationParameter(AnnotationNode an, String key, String defaultValue) {
if (an.values != null) {
int i = an.values.indexOf(key);
if (i % 2 == 0) {
return (String)an.values.get(i+1);
}
}
return defaultValue;
}
}

View File

@ -1,11 +1,9 @@
package com.alibaba.testable.agent.util;
import com.alibaba.testable.agent.constant.ConstPool;
import com.alibaba.testable.agent.model.MethodInfo;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.tree.AnnotationNode;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.MethodNode;
import java.util.ArrayList;
import java.util.HashMap;
@ -31,8 +29,6 @@ public class ClassUtil {
private static final char TYPE_ARRAY = '[';
private static final Map<Character, String> TYPE_MAPPING = new HashMap<Character, String>();
private static final String TARGET_CLASS = "targetClass";
private static final String TARGET_METHOD = "targetMethod";
static {
TYPE_MAPPING.put(TYPE_BYTE, "java/lang/Byte");
@ -64,47 +60,24 @@ public class ClassUtil {
}
/**
* Get testable inject method from test class
* @param className test class name
* get test class name from source class name
* @param sourceClassName source class name
*/
public static List<MethodInfo> getTestableInjectMethods(String className) {
try {
List<MethodInfo> methodInfos = new ArrayList<MethodInfo>();
ClassNode cn = new ClassNode();
new ClassReader(className).accept(cn, 0);
for (MethodNode mn : cn.methods) {
checkMethodAnnotation(cn, methodInfos, mn);
}
return methodInfos;
} catch (Exception e) {
return new ArrayList<MethodInfo>();
}
public static String getTestClassName(String sourceClassName) {
return sourceClassName + ConstPool.TEST_POSTFIX;
}
private static void checkMethodAnnotation(ClassNode cn, List<MethodInfo> methodInfos, MethodNode mn) {
if (mn.visibleAnnotations == null) {
return;
}
for (AnnotationNode an : mn.visibleAnnotations) {
if (toDotSeparateFullClassName(an.desc).equals(ConstPool.TESTABLE_INJECT)) {
String targetClass = getAnnotationParameter(an, TARGET_CLASS, StringUtil.getSourceClassName(cn.name));
String targetMethod = getAnnotationParameter(an, TARGET_METHOD, mn.name);
methodInfos.add(new MethodInfo(toSlashSeparateName(targetClass), targetMethod, mn.desc));
break;
}
}
}
private static String getAnnotationParameter(AnnotationNode an, String key, String defaultValue) {
if (an.values != null) {
int i = an.values.indexOf(key);
if (i % 2 == 0) {
return (String)an.values.get(i+1);
}
}
return defaultValue;
/**
* get source class name from test class name
* @param testClassName test class name
*/
public static String getSourceClassName(String testClassName) {
return testClassName.substring(0, testClassName.length() - ConstPool.TEST_POSTFIX.length());
}
/**
* parse method desc, fetch parameter types
*/
public static List<Byte> getParameterTypes(String desc) {
List<Byte> parameterTypes = new ArrayList<Byte>();
boolean travelingClass = false;
@ -127,6 +100,9 @@ public class ClassUtil {
return parameterTypes;
}
/**
* parse method desc, fetch return value types
*/
public static String getReturnType(String desc) {
int returnTypeEdge = desc.lastIndexOf(PARAM_END);
char typeChar = desc.charAt(returnTypeEdge + 1);
@ -141,14 +117,23 @@ public class ClassUtil {
}
}
private static String toSlashSeparateName(String name) {
/**
* convert dot separated name to slash separated name
*/
public static String toSlashSeparatedName(String name) {
return name.replace(ConstPool.DOT, ConstPool.SLASH);
}
/**
* convert dot separated name to byte code class name
*/
public static String toByteCodeClassName(String className) {
return TYPE_CLASS + toSlashSeparateName(className) + CLASS_END;
return TYPE_CLASS + toSlashSeparatedName(className) + CLASS_END;
}
/**
* convert byte code class name to dot separated human readable name
*/
public static String toDotSeparateFullClassName(String className) {
return className.replace(ConstPool.SLASH, ConstPool.DOT).substring(1, className.length() - 1);
}

View File

@ -1,7 +1,5 @@
package com.alibaba.testable.agent.util;
import com.alibaba.testable.agent.constant.ConstPool;
/**
* @author flin
*/
@ -20,20 +18,4 @@ public class StringUtil {
return sb.toString();
}
/**
* get test class name from source class name
* @param sourceClassName source class name
*/
public static String getTestClassName(String sourceClassName) {
return sourceClassName + ConstPool.TEST_POSTFIX;
}
/**
* get source class name from test class name
* @param testClassName test class name
*/
public static String getSourceClassName(String testClassName) {
return testClassName.substring(0, testClassName.length() - ConstPool.TEST_POSTFIX.length());
}
}