more effective check

This commit is contained in:
金戟 2020-07-21 12:19:08 +08:00
parent 592731a3e6
commit 6103a691ce

View File

@ -1,6 +1,5 @@
package com.alibaba.testable.transformer; package com.alibaba.testable.transformer;
import com.alibaba.testable.model.MethodInfo;
import com.alibaba.testable.visitor.MethodRecordVisitor; import com.alibaba.testable.visitor.MethodRecordVisitor;
import com.alibaba.testable.visitor.TestableVisitor; import com.alibaba.testable.visitor.TestableVisitor;
import org.objectweb.asm.ClassReader; import org.objectweb.asm.ClassReader;
@ -8,10 +7,9 @@ import org.objectweb.asm.ClassWriter;
import java.lang.annotation.Annotation; import java.lang.annotation.Annotation;
import java.lang.instrument.ClassFileTransformer; import java.lang.instrument.ClassFileTransformer;
import java.lang.reflect.Field;
import java.security.ProtectionDomain; import java.security.ProtectionDomain;
import java.util.List; import java.util.HashSet;
import java.util.Vector; import java.util.Set;
public class TestableFileTransformer implements ClassFileTransformer { public class TestableFileTransformer implements ClassFileTransformer {
@ -20,6 +18,8 @@ public class TestableFileTransformer implements ClassFileTransformer {
private static final String SLASH = "/"; private static final String SLASH = "/";
private static final String TEST_POSTFIX = "Test"; private static final String TEST_POSTFIX = "Test";
private static Set<String> loadedClassNames = new HashSet<String>();
public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined, public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined,
ProtectionDomain protectionDomain, byte[] classfileBuffer) { ProtectionDomain protectionDomain, byte[] classfileBuffer) {
if (null == loader || null == className) { if (null == loader || null == className) {
@ -27,8 +27,8 @@ public class TestableFileTransformer implements ClassFileTransformer {
return null; return null;
} }
String dotClassName = className.replace(SLASH, DOT); String dotClassName = className.replace(SLASH, DOT);
MethodRecordVisitor methodRecordVisitor = getMemberMethods(classfileBuffer, loadedClassNames.add(dotClassName);
isTestClassTestable(loader, dotClassName)); MethodRecordVisitor methodRecordVisitor = getMemberMethods(classfileBuffer, checkTestClass(dotClassName));
if (!methodRecordVisitor.isNeedTransform()) { if (!methodRecordVisitor.isNeedTransform()) {
// Neither EnableTestable on test class, nor EnableTestableInject on source class // Neither EnableTestable on test class, nor EnableTestableInject on source class
return null; return null;
@ -40,33 +40,21 @@ public class TestableFileTransformer implements ClassFileTransformer {
return writer.toByteArray(); return writer.toByteArray();
} }
private boolean isTestClassTestable(ClassLoader loader, String dotClassName) { private boolean checkTestClass(String dotClassName) {
boolean needTransform = false; String testClassName = dotClassName + TEST_POSTFIX;
try { if (loadedClassNames.contains(testClassName)) {
Field classesField = ClassLoader.class.getDeclaredField("classes"); try {
classesField.setAccessible(true); Class<?> testClazz = Class.forName(testClassName);
Vector<Class> classesVector = (Vector<Class>)classesField.get(loader); for (Annotation a : testClazz.getAnnotations()) {
if (null != classesVector) { if (a.annotationType().getName().equals(ENABLE_TESTABLE)) {
for (Class c : classesVector) { return true;
String testClassName = dotClassName + TEST_POSTFIX;
if (c.getName().endsWith(testClassName)) {
Class<?> testClazz = Class.forName(testClassName);
for (Annotation a : testClazz.getAnnotations()) {
if (a.annotationType().getName().equals(ENABLE_TESTABLE)) {
needTransform = true;
}
}
} }
} }
} catch (ClassNotFoundException e) {
return false;
} }
} catch (NoSuchFieldException e) {
return false;
} catch (IllegalAccessException e) {
return false;
} catch (ClassNotFoundException e) {
return false;
} }
return needTransform; return false;
} }
private MethodRecordVisitor getMemberMethods(byte[] classfileBuffer, boolean needTransform) { private MethodRecordVisitor getMemberMethods(byte[] classfileBuffer, boolean needTransform) {