refactor transformer

This commit is contained in:
金戟 2021-02-17 13:45:57 +08:00
parent 01bd676df7
commit dab1d36a81
5 changed files with 225 additions and 185 deletions

View File

@ -0,0 +1,151 @@
package com.alibaba.testable.agent.transformer;
import com.alibaba.testable.agent.constant.ConstPool;
import com.alibaba.testable.agent.model.MethodInfo;
import com.alibaba.testable.agent.tool.ImmutablePair;
import com.alibaba.testable.agent.util.AnnotationUtil;
import com.alibaba.testable.agent.util.ClassUtil;
import com.alibaba.testable.agent.util.DiagnoseUtil;
import com.alibaba.testable.core.util.LogUtil;
import com.alibaba.testable.core.util.MockContextUtil;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AnnotationNode;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.MethodNode;
import java.util.ArrayList;
import java.util.List;
import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName;
public class MockClassParser {
private static final String CLASS_OBJECT = "java/lang/Object";
/**
* Get information of all mock methods
* @param className mock class name
* @return list of mock methods
*/
public List<MethodInfo> getTestableMockMethods(String className) {
List<MethodInfo> methodInfos = new ArrayList<MethodInfo>();
ClassNode cn = ClassUtil.getClassNode(className);
if (cn == null) {
return new ArrayList<MethodInfo>();
}
for (MethodNode mn : getAllMethods(cn)) {
checkMethodAnnotation(cn, methodInfos, mn);
}
LogUtil.diagnose(" Found %d mock methods", methodInfos.size());
return methodInfos;
}
/**
* Check whether any method in specified class has mock-related annotation
*
* @param className class that need to explore
* @return found annotation or not
*/
public boolean isMockClass(String className) {
return MockContextUtil.mockToTests.containsKey(ClassUtil.toDotSeparatedName(className)) ||
hasMockMethod(className);
}
private boolean hasMockMethod(String className) {
ClassNode cn = ClassUtil.getClassNode(className);
if (cn == null) {
return false;
}
DiagnoseUtil.setupByClass(cn);
for (MethodNode mn : cn.methods) {
if (mn.visibleAnnotations != null) {
for (AnnotationNode an : mn.visibleAnnotations) {
String fullClassName = toDotSeparateFullClassName(an.desc);
if (fullClassName.equals(ConstPool.MOCK_METHOD) ||
fullClassName.equals(ConstPool.MOCK_CONSTRUCTOR)) {
return true;
}
}
}
}
return false;
}
private List<MethodNode> getAllMethods(ClassNode cn) {
List<MethodNode> mns = new ArrayList<MethodNode>(cn.methods);
if (cn.superName != null && !cn.superName.equals(CLASS_OBJECT)) {
ClassNode scn = ClassUtil.getClassNode(cn.superName);
if (scn != null) {
mns.addAll(getAllMethods(scn));
}
}
return mns;
}
private void checkMethodAnnotation(ClassNode cn, List<MethodInfo> methodInfos, MethodNode mn) {
if (mn.visibleAnnotations == null) {
return;
}
for (AnnotationNode an : mn.visibleAnnotations) {
String fullClassName = toDotSeparateFullClassName(an.desc);
if (fullClassName.equals(ConstPool.MOCK_CONSTRUCTOR)) {
LogUtil.verbose(" Mock constructor \"%s\" as \"(%s)V\" for \"%s\"", mn.name,
ClassUtil.extractParameters(mn.desc), ClassUtil.getReturnType(mn.desc));
addMockConstructor(methodInfos, cn, mn);
} else if (fullClassName.equals(ConstPool.MOCK_METHOD)) {
LogUtil.verbose(" Mock method \"%s\" as \"%s\"", mn.name, getTargetMethodDesc(mn, an));
String targetMethod = AnnotationUtil.getAnnotationParameter(
an, ConstPool.FIELD_TARGET_METHOD, mn.name, String.class);
if (ConstPool.CONSTRUCTOR.equals(targetMethod)) {
addMockConstructor(methodInfos, cn, mn);
} else {
MethodInfo mi = getMethodInfo(mn, an, targetMethod);
if (mi != null) {
methodInfos.add(mi);
}
}
break;
}
}
}
private String getTargetMethodDesc(MethodNode mn, AnnotationNode mockMethodAnnotation) {
Type type = AnnotationUtil.getAnnotationParameter(mockMethodAnnotation, ConstPool.FIELD_TARGET_CLASS,
null, Type.class);
return type == null ? ClassUtil.removeFirstParameter(mn.desc) : mn.desc;
}
private MethodInfo getMethodInfo(MethodNode mn, AnnotationNode an, String targetMethod) {
Type targetType = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_CLASS, null, Type.class);
if (targetType == null) {
// "targetClass" unset, use first parameter as target class type
ImmutablePair<String, String> methodDescPair = extractFirstParameter(mn.desc);
if (methodDescPair == null) {
return null;
}
return new MethodInfo(methodDescPair.left, targetMethod, methodDescPair.right, mn.name, mn.desc);
} else {
// "targetClass" found, use it as target class type
String slashSeparatedName = ClassUtil.toSlashSeparatedName(targetType.getClassName());
return new MethodInfo(slashSeparatedName, targetMethod, mn.desc, mn.name,
ClassUtil.addParameterAtBegin(mn.desc, ClassUtil.toByteCodeClassName(slashSeparatedName)));
}
}
private void addMockConstructor(List<MethodInfo> methodInfos, ClassNode cn, MethodNode mn) {
String sourceClassName = ClassUtil.getSourceClassName(cn.name);
methodInfos.add(new MethodInfo(sourceClassName, ConstPool.CONSTRUCTOR, mn.desc, mn.name, mn.desc));
}
/**
* Split desc to "first parameter" and "desc of rest parameters"
* @param desc method desc
*/
private ImmutablePair<String, String> extractFirstParameter(String desc) {
// assume first parameter is a class
int pos = desc.indexOf(";");
return pos < 0 ? null : ImmutablePair.of(desc.substring(2, pos), "(" + desc.substring(pos + 1));
}
}

View File

@ -5,29 +5,22 @@ import com.alibaba.testable.agent.handler.MockClassHandler;
import com.alibaba.testable.agent.handler.SourceClassHandler;
import com.alibaba.testable.agent.handler.TestClassHandler;
import com.alibaba.testable.agent.model.MethodInfo;
import com.alibaba.testable.agent.tool.ImmutablePair;
import com.alibaba.testable.agent.util.AnnotationUtil;
import com.alibaba.testable.agent.util.ClassUtil;
import com.alibaba.testable.agent.util.GlobalConfig;
import com.alibaba.testable.agent.util.StringUtil;
import com.alibaba.testable.agent.util.*;
import com.alibaba.testable.core.model.ClassType;
import com.alibaba.testable.core.model.LogLevel;
import com.alibaba.testable.core.util.LogUtil;
import com.alibaba.testable.core.util.MockContextUtil;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AnnotationNode;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.InnerClassNode;
import org.objectweb.asm.tree.MethodNode;
import javax.lang.model.type.NullType;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.lang.instrument.ClassFileTransformer;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.List;
import static com.alibaba.testable.agent.constant.ConstPool.*;
@ -42,10 +35,8 @@ public class TestableClassTransformer implements ClassFileTransformer {
private static final String FIELD_VALUE = "value";
private static final String FIELD_TREAT_AS = "treatAs";
private static final String FIELD_DIAGNOSE = "diagnose";
private static final String COMMA = ",";
private static final String CLASS_NAME_MOCK = "Mock";
private static final String CLASS_OBJECT = "java/lang/Object";
/**
* Just avoid spend time to scan those surely non-user classes Should keep these lists as tiny as possible
@ -54,6 +45,8 @@ public class TestableClassTransformer implements ClassFileTransformer {
private final String[] BLACKLIST_PREFIXES = new String[] {"jdk/", "java/", "javax/", "com/sun/",
"org/apache/maven/", "com/alibaba/testable/", "junit/", "org/junit/", "org/testng/"};
public MockClassParser mockClassParser = new MockClassParser();
@Override
public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined,
ProtectionDomain protectionDomain, byte[] classFileBuffer) {
@ -64,7 +57,7 @@ public class TestableClassTransformer implements ClassFileTransformer {
LogUtil.verbose("Handle class: " + className);
byte[] bytes = null;
try {
if (isMockClass(className)) {
if (mockClassParser.isMockClass(className)) {
// it's a mock class
LogUtil.diagnose("Handling mock class %s", className);
bytes = new MockClassHandler(className).getBytes(classFileBuffer);
@ -80,7 +73,7 @@ public class TestableClassTransformer implements ClassFileTransformer {
mockClass = foundMockForSourceClass(className);
if (mockClass != null) {
// it's a source class with testable enabled
List<MethodInfo> injectMethods = getTestableMockMethods(mockClass);
List<MethodInfo> injectMethods = mockClassParser.getTestableMockMethods(mockClass);
LogUtil.diagnose("Handling source class %s", className);
bytes = new SourceClassHandler(injectMethods, mockClass).getBytes(classFileBuffer);
dumpByte(className, bytes);
@ -127,17 +120,12 @@ public class TestableClassTransformer implements ClassFileTransformer {
return mockClass;
}
mockClass = ClassUtil.getMockClassName(ClassUtil.getSourceClassName(className));
if (isMockClass(mockClass)) {
if (mockClassParser.isMockClass(mockClass)) {
return mockClass;
}
return null;
}
private boolean isMockClass(String className) {
return MockContextUtil.mockToTests.containsKey(ClassUtil.toDotSeparatedName(className)) ||
hasMockMethod(className);
}
private boolean isSystemClass(String className) {
// className can be null for Java 8 lambdas
if (null == className) {
@ -167,85 +155,6 @@ public class TestableClassTransformer implements ClassFileTransformer {
return false;
}
private List<MethodInfo> getTestableMockMethods(String className) {
List<MethodInfo> methodInfos = new ArrayList<MethodInfo>();
ClassNode cn = getClassNode(className);
if (cn == null) {
return new ArrayList<MethodInfo>();
}
for (MethodNode mn : getAllMethods(cn)) {
checkMethodAnnotation(cn, methodInfos, mn);
}
LogUtil.diagnose(" Found %d mock methods", methodInfos.size());
return methodInfos;
}
private List<MethodNode> getAllMethods(ClassNode cn) {
List<MethodNode> mns = new ArrayList<MethodNode>(cn.methods);
if (cn.superName != null && !cn.superName.equals(CLASS_OBJECT)) {
ClassNode scn = getClassNode(cn.superName);
if (scn != null) {
mns.addAll(getAllMethods(scn));
}
}
return mns;
}
private void checkMethodAnnotation(ClassNode cn, List<MethodInfo> methodInfos, MethodNode mn) {
if (mn.visibleAnnotations == null) {
return;
}
for (AnnotationNode an : mn.visibleAnnotations) {
String fullClassName = toDotSeparateFullClassName(an.desc);
if (fullClassName.equals(ConstPool.MOCK_CONSTRUCTOR)) {
LogUtil.verbose(" Mock constructor \"%s\" as \"(%s)V\" for \"%s\"", mn.name,
ClassUtil.extractParameters(mn.desc), ClassUtil.getReturnType(mn.desc));
addMockConstructor(methodInfos, cn, mn);
} else if (fullClassName.equals(ConstPool.MOCK_METHOD)) {
LogUtil.verbose(" Mock method \"%s\" as \"%s\"", mn.name, getTargetMethodDesc(mn, an));
String targetMethod = AnnotationUtil.getAnnotationParameter(
an, ConstPool.FIELD_TARGET_METHOD, mn.name, String.class);
if (ConstPool.CONSTRUCTOR.equals(targetMethod)) {
addMockConstructor(methodInfos, cn, mn);
} else {
MethodInfo mi = getMethodInfo(mn, an, targetMethod);
if (mi != null) {
methodInfos.add(mi);
}
}
break;
}
}
}
private String getTargetMethodDesc(MethodNode mn, AnnotationNode mockMethodAnnotation) {
Type type = AnnotationUtil.getAnnotationParameter(mockMethodAnnotation, ConstPool.FIELD_TARGET_CLASS,
null, Type.class);
return type == null ? ClassUtil.removeFirstParameter(mn.desc) : mn.desc;
}
private MethodInfo getMethodInfo(MethodNode mn, AnnotationNode an, String targetMethod) {
Type targetType = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_CLASS, null, Type.class);
if (targetType == null) {
// "targetClass" unset, use first parameter as target class type
ImmutablePair<String, String> methodDescPair = extractFirstParameter(mn.desc);
if (methodDescPair == null) {
return null;
}
return new MethodInfo(methodDescPair.left, targetMethod, methodDescPair.right, mn.name, mn.desc);
} else {
// "targetClass" found, use it as target class type
String slashSeparatedName = ClassUtil.toSlashSeparatedName(targetType.getClassName());
return new MethodInfo(slashSeparatedName, targetMethod, mn.desc, mn.name,
ClassUtil.addParameterAtBegin(mn.desc, ClassUtil.toByteCodeClassName(slashSeparatedName)));
}
}
private void addMockConstructor(List<MethodInfo> methodInfos, ClassNode cn, MethodNode mn) {
String sourceClassName = ClassUtil.getSourceClassName(cn.name);
methodInfos.add(new MethodInfo(sourceClassName, ConstPool.CONSTRUCTOR, mn.desc, mn.name, mn.desc));
}
/**
* Read @MockWith annotation upon class to fetch mock class
*
@ -253,7 +162,7 @@ public class TestableClassTransformer implements ClassFileTransformer {
* @return name of mock class, null for not found
*/
private String readMockWithAnnotationAsSourceClass(String className) {
ClassNode cn = getClassNode(className);
ClassNode cn = ClassUtil.getClassNode(className);
if (cn == null) {
return null;
}
@ -267,7 +176,7 @@ public class TestableClassTransformer implements ClassFileTransformer {
* @return name of mock class, null for not found
*/
private String readMockWithAnnotationAndInnerClassAsTestClass(String className) {
ClassNode cn = getClassNode(className);
ClassNode cn = ClassUtil.getClassNode(className);
if (cn == null) {
return null;
}
@ -299,7 +208,7 @@ public class TestableClassTransformer implements ClassFileTransformer {
private String parseMockWithAnnotation(ClassNode cn, ClassType expectedType) {
if (cn.visibleAnnotations != null) {
for (AnnotationNode an : cn.visibleAnnotations) {
setupDiagnose(an);
DiagnoseUtil.setupByAnnotation(an);
if (toDotSeparateFullClassName(an.desc).equals(ConstPool.MOCK_WITH)) {
ClassType type = AnnotationUtil.getAnnotationParameter(an, FIELD_TREAT_AS, ClassType.GuessByName,
ClassType.class);
@ -324,80 +233,8 @@ public class TestableClassTransformer implements ClassFileTransformer {
}
}
/**
* Check whether any method in specified class has mock-related annotation
*
* @param className class that need to explore
* @return found annotation or not
*/
private boolean hasMockMethod(String className) {
ClassNode cn = getClassNode(className);
if (cn == null) {
return false;
}
setupDiagnose(cn);
for (MethodNode mn : cn.methods) {
if (mn.visibleAnnotations != null) {
for (AnnotationNode an : mn.visibleAnnotations) {
String fullClassName = toDotSeparateFullClassName(an.desc);
if (fullClassName.equals(ConstPool.MOCK_METHOD) ||
fullClassName.equals(ConstPool.MOCK_CONSTRUCTOR)) {
return true;
}
}
}
}
return false;
}
private String getInnerMockClassName(String className) {
return className + DOLLAR + CLASS_NAME_MOCK;
}
private ClassNode getClassNode(String className) {
ClassNode cn = new ClassNode();
try {
new ClassReader(className).accept(cn, 0);
} catch (IOException e) {
return null;
}
return cn;
}
private void setupDiagnose(ClassNode cn) {
if (cn.visibleAnnotations == null) {
return;
}
for (AnnotationNode an : cn.visibleAnnotations) {
setupDiagnose(an);
}
}
private void setupDiagnose(AnnotationNode an) {
if (toDotSeparateFullClassName(an.desc).equals(MOCK_WITH)) {
setupDianose(an, FIELD_DIAGNOSE);
}
if (toDotSeparateFullClassName(an.desc).equals(ConstPool.MOCK_DIAGNOSE)) {
setupDianose(an, FIELD_VALUE);
}
}
private void setupDianose(AnnotationNode an, String fieldDiagnose) {
LogLevel level = AnnotationUtil.getAnnotationParameter(an, fieldDiagnose, null, LogLevel.class);
if (level != null) {
LogUtil.setLevel(level == LogLevel.ENABLE ? LogUtil.LogLevel.LEVEL_DIAGNOSE :
(level == LogLevel.VERBOSE ? LogUtil.LogLevel.LEVEL_VERBOSE : LogUtil.LogLevel.LEVEL_MUTE));
}
}
/**
* Split desc to "first parameter" and "desc of rest parameters"
* @param desc method desc
*/
private ImmutablePair<String, String> extractFirstParameter(String desc) {
// assume first parameter is a class
int pos = desc.indexOf(";");
return pos < 0 ? null : ImmutablePair.of(desc.substring(2, pos), "(" + desc.substring(pos + 1));
}
}

View File

@ -1,8 +1,11 @@
package com.alibaba.testable.agent.util;
import com.alibaba.testable.agent.constant.ConstPool;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.MethodInsnNode;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@ -220,15 +223,6 @@ public class ClassUtil {
return toDotSeparatedName(className).substring(1, className.length() - 1);
}
/**
* convert byte code class name to slash separated human readable name
* @param className original name
* @return converted name
*/
public static String toSlashSeparateFullClassName(String className) {
return toSlashSeparatedName(className).substring(1, className.length() - 1);
}
/**
* remove first parameter from method descriptor
* @param desc original descriptor
@ -248,6 +242,21 @@ public class ClassUtil {
return "(" + type + desc.substring(1);
}
/**
* Read class from current context
* @param className class name
* @return loaded class
*/
public static ClassNode getClassNode(String className) {
ClassNode cn = new ClassNode();
try {
new ClassReader(className).accept(cn, 0);
} catch (IOException e) {
return null;
}
return cn;
}
private static String toDescriptor(Byte type, String objectType) {
return "(" + (char)type.byteValue() + ")L" + objectType + ";";
}

View File

@ -0,0 +1,43 @@
package com.alibaba.testable.agent.util;
import com.alibaba.testable.agent.constant.ConstPool;
import com.alibaba.testable.core.model.LogLevel;
import com.alibaba.testable.core.util.LogUtil;
import org.objectweb.asm.tree.AnnotationNode;
import org.objectweb.asm.tree.ClassNode;
import static com.alibaba.testable.agent.constant.ConstPool.MOCK_WITH;
import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName;
public class DiagnoseUtil {
private static final String FIELD_VALUE = "value";
private static final String FIELD_DIAGNOSE = "diagnose";
public static void setupByClass(ClassNode cn) {
if (cn.visibleAnnotations == null) {
return;
}
for (AnnotationNode an : cn.visibleAnnotations) {
setupByAnnotation(an);
}
}
public static void setupByAnnotation(AnnotationNode an) {
if (toDotSeparateFullClassName(an.desc).equals(MOCK_WITH)) {
setupDiagnose(an, FIELD_DIAGNOSE);
}
if (toDotSeparateFullClassName(an.desc).equals(ConstPool.MOCK_DIAGNOSE)) {
setupDiagnose(an, FIELD_VALUE);
}
}
private static void setupDiagnose(AnnotationNode an, String fieldDiagnose) {
LogLevel level = AnnotationUtil.getAnnotationParameter(an, fieldDiagnose, null, LogLevel.class);
if (level != null) {
LogUtil.setLevel(level == LogLevel.ENABLE ? LogUtil.LogLevel.LEVEL_DIAGNOSE :
(level == LogLevel.VERBOSE ? LogUtil.LogLevel.LEVEL_VERBOSE : LogUtil.LogLevel.LEVEL_MUTE));
}
}
}

View File

@ -6,16 +6,16 @@ import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class TestableClassTransformerTest {
class MockClassParserTest {
private TestableClassTransformer testableClassTransformer = new TestableClassTransformer();
private MockClassParser mockClassParser = new MockClassParser();
@Test
void should_split_parameters() {
ImmutablePair<String, String> parameters =
PrivateAccessor.invoke(testableClassTransformer, "extractFirstParameter", "()");
PrivateAccessor.invoke(mockClassParser, "extractFirstParameter", "()");
assertNull(parameters);
parameters = PrivateAccessor.invoke(testableClassTransformer, "extractFirstParameter", "(Lcom.alibaba.demo.Class;ILjava.lang.String;Z)");
parameters = PrivateAccessor.invoke(mockClassParser, "extractFirstParameter", "(Lcom.alibaba.demo.Class;ILjava.lang.String;Z)");
assertNotNull(parameters);
assertEquals("com.alibaba.demo.Class", parameters.left);
assertEquals("(ILjava.lang.String;Z)", parameters.right);