always unfold target class parameter of MockMethod annotation to method parameter

This commit is contained in:
金戟 2021-02-16 16:09:05 +08:00
parent be53ea2d9c
commit d8ffcacdaf
5 changed files with 68 additions and 29 deletions

View File

@ -34,12 +34,46 @@ public class MockClassHandler extends BaseClassWithContextHandler {
mn.access &= ~ACC_PRIVATE;
mn.access &= ~ACC_PROTECTED;
mn.access |= ACC_PUBLIC;
unfoldTargetClass(mn);
injectInvokeRecorder(mn);
handleTestableUtil(mn);
}
}
}
/**
* put targetClass parameter in @MockMethod to first parameter of the mock method
*/
private void unfoldTargetClass(MethodNode mn) {
String targetClassName = null;
for (AnnotationNode an : mn.visibleAnnotations) {
if (ClassUtil.toByteCodeClassName(ConstPool.MOCK_METHOD).equals(an.desc)) {
Type type = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_CLASS, null, Type.class);
if (type != null && !type.getClassName().equals(NullType.class.getName())) {
targetClassName = ClassUtil.toByteCodeClassName(type.getClassName());
}
AnnotationUtil.removeAnnotationParameter(an, ConstPool.FIELD_TARGET_CLASS);
}
}
if (targetClassName != null) {
mn.desc = ClassUtil.addParameterAtBegin(mn.desc, targetClassName);
LocalVariableNode thisRef = mn.localVariables.get(0);
mn.localVariables.add(1, new LocalVariableNode("__self", targetClassName, null,
thisRef.start, thisRef.end, 1));
for (int i = 2; i < mn.localVariables.size(); i++) {
mn.localVariables.get(i).index++;
}
for (AbstractInsnNode in : mn.instructions) {
if (in instanceof IincInsnNode) {
((IincInsnNode)in).var++;
} else if (in instanceof VarInsnNode) {
((VarInsnNode)in).var++;
}
}
mn.maxLocals++;
}
}
private void addGetInstanceMethod(ClassNode cn) {
MethodNode getInstanceMethod = new MethodNode(ACC_PUBLIC | ACC_STATIC, GET_TESTABLE_REF,
VOID_ARGS + ClassUtil.toByteCodeClassName(mockClassName), null, null);
@ -100,11 +134,7 @@ public class MockClassHandler extends BaseClassWithContextHandler {
} else {
il.add(new InsnNode(ICONST_0));
}
if (isTargetClassInParameter(mn)) {
il.add(new InsnNode(ICONST_1));
} else {
il.add(new InsnNode(ICONST_0));
}
il.add(new InsnNode(ICONST_1));
il.add(new MethodInsnNode(INVOKESTATIC, CLASS_INVOKE_RECORD_UTIL, METHOD_RECORD_MOCK_INVOKE,
SIGNATURE_INVOKE_RECORDER_METHOD, false));
mn.instructions.insertBefore(mn.instructions.get(0), il);
@ -126,18 +156,6 @@ public class MockClassHandler extends BaseClassWithContextHandler {
return false;
}
private boolean isTargetClassInParameter(MethodNode mn) {
for (AnnotationNode an : mn.visibleAnnotations) {
if (ConstPool.MOCK_METHOD.equals(toDotSeparateFullClassName(an.desc))) {
Type type = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_CLASS, null, Type.class);
if (type != null && !type.getClassName().equals(NullType.class.getName())) {
return false;
}
}
}
return true;
}
private static ImmutablePair<Integer, Integer> getLoadParameterByteCode(Byte type) {
switch (type) {
case ClassUtil.TYPE_BYTE:

View File

@ -233,24 +233,15 @@ public class SourceClassHandler extends BaseClassHandler {
int opcode, int start, int end) {
LogUtil.diagnose(" Line %d, mock method \"%s\" used", getLineNum(instructions, start),
mockMethod.getMockName());
boolean shouldAppendTypeParameter = !mockMethod.getDesc().equals(mockMethod.getMockDesc());
mn.instructions.insertBefore(instructions[start], new MethodInsnNode(INVOKESTATIC, mockClassName,
GET_TESTABLE_REF, VOID_ARGS + ClassUtil.toByteCodeClassName(mockClassName), false));
if (Opcodes.INVOKESTATIC == opcode || isCompanionMethod(ownerClass, opcode)) {
if (shouldAppendTypeParameter) {
// append a null value if it was a static invoke or in kotlin companion class
mn.instructions.insertBefore(instructions[start], new InsnNode(ACONST_NULL));
}
// append a null value if it was a static invoke or in kotlin companion class
mn.instructions.insertBefore(instructions[start], new InsnNode(ACONST_NULL));
if (ClassUtil.isCompanionClassName(ownerClass)) {
// for kotlin companion class, remove the byte code of reference to "companion" static field
mn.instructions.remove(instructions[end - 1]);
}
} else if (!shouldAppendTypeParameter) {
// remove extra ops code of the mocked instance, which was used as first parameter of mock method
ImmutablePair<Integer, Integer> range = findRangeOfInvokerInstance(instructions, start, end);
for (int i = range.left; i <= range.right; i++) {
mn.instructions.remove(instructions[i]);
}
}
// method with @MockMethod will be modified as public access, so INVOKEVIRTUAL is used
mn.instructions.insertBefore(instructions[end], new MethodInsnNode(INVOKEVIRTUAL, mockClassName,

View File

@ -230,7 +230,8 @@ public class TestableClassTransformer implements ClassFileTransformer {
} else {
// "targetClass" found, use it as target class type
String slashSeparatedName = ClassUtil.toSlashSeparatedName(targetType.getClassName());
return new MethodInfo(slashSeparatedName, targetMethod, mn.desc, mn.name, mn.desc);
return new MethodInfo(slashSeparatedName, targetMethod, mn.desc, mn.name,
ClassUtil.addParameterAtBegin(mn.desc, ClassUtil.toByteCodeClassName(slashSeparatedName)));
}
}

View File

@ -40,4 +40,23 @@ public class AnnotationUtil {
return defaultValue;
}
/**
* Remove specified parameter from annotation
* @param an annotation node
* @param key name of parameter to remove
* @return true - success, false - not found
*/
public static boolean removeAnnotationParameter(AnnotationNode an, String key) {
if (an.values == null) {
return false;
}
for (int i = 0; i < an.values.size(); i += 2) {
if (an.values.get(i).equals(key)) {
an.values.remove(i + 1);
an.values.remove(i);
return true;
}
}
return false;
}
}

View File

@ -238,6 +238,16 @@ public class ClassUtil {
return "(" + desc.substring(desc.indexOf(";") + 1);
}
/**
* add extra parameter to the beginning of method descriptor
* @param desc original descriptor
* @param type byte code class name
* @return descriptor with specified parameter at begin
*/
public static String addParameterAtBegin(String desc, String type) {
return "(" + type + desc.substring(1);
}
private static String toDescriptor(Byte type, String objectType) {
return "(" + (char)type.byteValue() + ")L" + objectType + ";";
}