fit for static mock method

This commit is contained in:
金戟 2021-02-18 18:15:11 +08:00
parent 8be5550331
commit bdd99577c4
5 changed files with 98 additions and 47 deletions

View File

@ -4,6 +4,7 @@ import com.alibaba.testable.agent.constant.ConstPool;
import com.alibaba.testable.agent.tool.ImmutablePair; import com.alibaba.testable.agent.tool.ImmutablePair;
import com.alibaba.testable.agent.util.AnnotationUtil; import com.alibaba.testable.agent.util.AnnotationUtil;
import com.alibaba.testable.agent.util.ClassUtil; import com.alibaba.testable.agent.util.ClassUtil;
import com.alibaba.testable.agent.util.MethodUtil;
import com.alibaba.testable.core.model.MockScope; import com.alibaba.testable.core.model.MockScope;
import org.objectweb.asm.Label; import org.objectweb.asm.Label;
import org.objectweb.asm.Type; import org.objectweb.asm.Type;
@ -41,6 +42,7 @@ public class MockClassHandler extends BaseClassWithContextHandler {
mn.access &= ~ACC_PRIVATE; mn.access &= ~ACC_PRIVATE;
mn.access &= ~ACC_PROTECTED; mn.access &= ~ACC_PROTECTED;
mn.access |= ACC_PUBLIC; mn.access |= ACC_PUBLIC;
// below transform order is important
unfoldTargetClass(mn); unfoldTargetClass(mn);
injectInvokeRecorder(mn); injectInvokeRecorder(mn);
injectAssociationChecker(mn); injectAssociationChecker(mn);
@ -88,26 +90,57 @@ public class MockClassHandler extends BaseClassWithContextHandler {
} }
} }
if (targetClassName != null) { if (targetClassName != null) {
// must get label before method description changed
ImmutablePair<LabelNode, LabelNode> labels = getStartAndEndLabel(mn);
mn.desc = ClassUtil.addParameterAtBegin(mn.desc, targetClassName); mn.desc = ClassUtil.addParameterAtBegin(mn.desc, targetClassName);
LocalVariableNode thisRef = mn.localVariables.get(0); int parameterOffset = MethodUtil.isStaticMethod(mn) ? 0 : 1;
mn.localVariables.add(1, new LocalVariableNode("__self", targetClassName, null, mn.localVariables.add(parameterOffset, new LocalVariableNode("__self", targetClassName, null,
thisRef.start, thisRef.end, 1)); labels.left, labels.right, parameterOffset));
for (int i = 2; i < mn.localVariables.size(); i++) { for (int i = parameterOffset + 1; i < mn.localVariables.size(); i++) {
mn.localVariables.get(i).index++; mn.localVariables.get(i).index++;
} }
for (AbstractInsnNode in : mn.instructions) { for (AbstractInsnNode in : mn.instructions) {
if (in instanceof IincInsnNode) { if (in instanceof IincInsnNode) {
((IincInsnNode)in).var++; ((IincInsnNode)in).var++;
} else if (in instanceof VarInsnNode && ((VarInsnNode)in).var > 0) { } else if (in instanceof VarInsnNode && ((VarInsnNode)in).var >= parameterOffset) {
((VarInsnNode)in).var++; ((VarInsnNode)in).var++;
} else if (in instanceof FrameNode && ((FrameNode)in).type == F_FULL) { } else if (in instanceof FrameNode && ((FrameNode)in).type == F_FULL) {
((FrameNode)in).local.add(1, targetClassName); ((FrameNode)in).local.add(parameterOffset, targetClassName);
} }
} }
mn.maxLocals++; mn.maxLocals++;
} }
} }
private ImmutablePair<LabelNode, LabelNode> getStartAndEndLabel(MethodNode mn) {
if (MethodUtil.isStaticMethod(mn)) {
LabelNode startLabel = null, endLabel = null;
for (AbstractInsnNode n = mn.instructions.getFirst(); n != null; n = n.getNext()) {
if (n instanceof LabelNode) {
startLabel = (LabelNode)n;
break;
}
}
if (ClassUtil.extractParameters(mn.desc).isEmpty()) {
// for method without parameter, should manually add a ending label
endLabel = new LabelNode(new Label());
mn.instructions.add(endLabel);
} else {
// for method with parameters, find the existing ending label
for (AbstractInsnNode n = mn.instructions.getLast(); n != null; n = n.getPrevious()) {
if (n instanceof LabelNode) {
endLabel = (LabelNode)n;
break;
}
}
}
return ImmutablePair.of(startLabel, endLabel);
} else {
LocalVariableNode thisRef = mn.localVariables.get(0);
return ImmutablePair.of(thisRef.start, thisRef.end);
}
}
private void injectAssociationChecker(MethodNode mn) { private void injectAssociationChecker(MethodNode mn) {
if (isGlobalScope(mn)) { if (isGlobalScope(mn)) {
return; return;
@ -218,7 +251,7 @@ public class MockClassHandler extends BaseClassWithContextHandler {
int size = types.size(); int size = types.size();
il.add(getIntInsn(size)); il.add(getIntInsn(size));
il.add(new TypeInsnNode(ANEWARRAY, ClassUtil.CLASS_OBJECT)); il.add(new TypeInsnNode(ANEWARRAY, ClassUtil.CLASS_OBJECT));
int parameterOffset = 1; int parameterOffset = MethodUtil.isStaticMethod(mn) ? 0 : 1;
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
mn.maxStack += 3; mn.maxStack += 3;
il.add(new InsnNode(DUP)); il.add(new InsnNode(DUP));

View File

@ -2,7 +2,6 @@ package com.alibaba.testable.agent.handler;
import com.alibaba.testable.agent.model.MethodInfo; import com.alibaba.testable.agent.model.MethodInfo;
import com.alibaba.testable.agent.model.ModifiedInsnNodes; import com.alibaba.testable.agent.model.ModifiedInsnNodes;
import com.alibaba.testable.agent.tool.ImmutablePair;
import com.alibaba.testable.agent.util.BytecodeUtil; import com.alibaba.testable.agent.util.BytecodeUtil;
import com.alibaba.testable.agent.util.ClassUtil; import com.alibaba.testable.agent.util.ClassUtil;
import com.alibaba.testable.core.util.LogUtil; import com.alibaba.testable.core.util.LogUtil;
@ -49,11 +48,11 @@ public class SourceClassHandler extends BaseClassHandler {
} }
} }
for (MethodNode m : cn.methods) { for (MethodNode m : cn.methods) {
transformMethod(cn, m, memberInjectMethods, newOperatorInjectMethods); transformMethod(m, memberInjectMethods, newOperatorInjectMethods);
} }
} }
private void transformMethod(ClassNode cn, MethodNode mn, Set<MethodInfo> memberInjectMethods, private void transformMethod(MethodNode mn, Set<MethodInfo> memberInjectMethods,
Set<MethodInfo> newOperatorInjectMethods) { Set<MethodInfo> newOperatorInjectMethods) {
LogUtil.diagnose(" Handling method %s", mn.name); LogUtil.diagnose(" Handling method %s", mn.name);
AbstractInsnNode[] instructions = mn.instructions.toArray(); AbstractInsnNode[] instructions = mn.instructions.toArray();
@ -69,12 +68,12 @@ public class SourceClassHandler extends BaseClassHandler {
if (CONSTRUCTOR.equals(node.name)) { if (CONSTRUCTOR.equals(node.name)) {
LogUtil.verbose(" Line %d, constructing \"%s\" as \"%s\"", getLineNum(instructions, i), LogUtil.verbose(" Line %d, constructing \"%s\" as \"%s\"", getLineNum(instructions, i),
node.owner, node.desc); node.owner, node.desc);
String newOperatorInjectMethodName = getNewOperatorInjectMethodName(newOperatorInjectMethods, node); MethodInfo newOperatorInjectMethod = getNewOperatorInjectMethod(newOperatorInjectMethods, node);
if (newOperatorInjectMethodName != null) { if (newOperatorInjectMethod != null) {
// it's a new operation and an inject method for it exist // it's a new operation and an inject method for it exist
int rangeStart = getConstructorStart(instructions, node.owner, i); int rangeStart = getConstructorStart(instructions, node.owner, i);
if (rangeStart >= 0) { if (rangeStart >= 0) {
ModifiedInsnNodes modifiedInsnNodes = replaceNewOps(cn, mn, newOperatorInjectMethodName, ModifiedInsnNodes modifiedInsnNodes = replaceNewOps(mn, newOperatorInjectMethod,
instructions, rangeStart, i); instructions, rangeStart, i);
instructions = modifiedInsnNodes.nodes; instructions = modifiedInsnNodes.nodes;
maxStackDiff = Math.max(maxStackDiff, modifiedInsnNodes.stackDiff); maxStackDiff = Math.max(maxStackDiff, modifiedInsnNodes.stackDiff);
@ -89,7 +88,7 @@ public class SourceClassHandler extends BaseClassHandler {
// it's a member or static method and an inject method for it exist // it's a member or static method and an inject method for it exist
int rangeStart = getMemberMethodStart(instructions, i); int rangeStart = getMemberMethodStart(instructions, i);
if (rangeStart >= 0) { if (rangeStart >= 0) {
ModifiedInsnNodes modifiedInsnNodes = replaceMemberCallOps(cn, mn, mockMethod, ModifiedInsnNodes modifiedInsnNodes = replaceMemberCallOps(mn, mockMethod,
instructions, node.owner, node.getOpcode(), rangeStart, i); instructions, node.owner, node.getOpcode(), rangeStart, i);
instructions = modifiedInsnNodes.nodes; instructions = modifiedInsnNodes.nodes;
maxStackDiff = Math.max(maxStackDiff, modifiedInsnNodes.stackDiff); maxStackDiff = Math.max(maxStackDiff, modifiedInsnNodes.stackDiff);
@ -124,10 +123,10 @@ public class SourceClassHandler extends BaseClassHandler {
return null; return null;
} }
private String getNewOperatorInjectMethodName(Set<MethodInfo> newOperatorInjectMethods, MethodInsnNode node) { private MethodInfo getNewOperatorInjectMethod(Set<MethodInfo> newOperatorInjectMethods, MethodInsnNode node) {
for (MethodInfo m : newOperatorInjectMethods) { for (MethodInfo m : newOperatorInjectMethods) {
if (m.getDesc().equals(getConstructorInjectDesc(node))) { if (m.getDesc().equals(getConstructorInjectDesc(node))) {
return m.getMockName(); return m;
} }
} }
return null; return null;
@ -198,16 +197,19 @@ public class SourceClassHandler extends BaseClassHandler {
return ClassUtil.getParameterTypes(desc).size() - (ClassUtil.getReturnType(desc).equals(VOID_RES) ? 0 : 1); return ClassUtil.getParameterTypes(desc).size() - (ClassUtil.getReturnType(desc).equals(VOID_RES) ? 0 : 1);
} }
private ModifiedInsnNodes replaceNewOps(ClassNode cn, MethodNode mn, String newOperatorInjectMethodName, private ModifiedInsnNodes replaceNewOps(MethodNode mn, MethodInfo newOperatorInjectMethod,
AbstractInsnNode[] instructions, int start, int end) { AbstractInsnNode[] instructions, int start, int end) {
LogUtil.diagnose(" Line %d, mock method \"%s\" used", getLineNum(instructions, start), String mockMethodName = newOperatorInjectMethod.getMockName();
newOperatorInjectMethodName); int invokeOpcode = newOperatorInjectMethod.isStatic() ? INVOKESTATIC : INVOKEVIRTUAL;
LogUtil.diagnose(" Line %d, mock method \"%s\" used", getLineNum(instructions, start), mockMethodName);
String classType = ((TypeInsnNode)instructions[start]).desc; String classType = ((TypeInsnNode)instructions[start]).desc;
String constructorDesc = ((MethodInsnNode)instructions[end]).desc; String constructorDesc = ((MethodInsnNode)instructions[end]).desc;
if (!newOperatorInjectMethod.isStatic()) {
mn.instructions.insertBefore(instructions[start], new MethodInsnNode(INVOKESTATIC, mockClassName, mn.instructions.insertBefore(instructions[start], new MethodInsnNode(INVOKESTATIC, mockClassName,
GET_TESTABLE_REF, VOID_ARGS + ClassUtil.toByteCodeClassName(mockClassName), false)); GET_TESTABLE_REF, VOID_ARGS + ClassUtil.toByteCodeClassName(mockClassName), false));
mn.instructions.insertBefore(instructions[end], new MethodInsnNode(INVOKEVIRTUAL, mockClassName, }
newOperatorInjectMethodName, getConstructorInjectDesc(constructorDesc, classType), false)); mn.instructions.insertBefore(instructions[end], new MethodInsnNode(invokeOpcode, mockClassName,
mockMethodName, getConstructorInjectDesc(constructorDesc, classType), false));
mn.instructions.remove(instructions[start]); mn.instructions.remove(instructions[start]);
mn.instructions.remove(instructions[start + 1]); mn.instructions.remove(instructions[start + 1]);
mn.instructions.remove(instructions[end]); mn.instructions.remove(instructions[end]);
@ -228,13 +230,14 @@ public class SourceClassHandler extends BaseClassHandler {
ClassUtil.toByteCodeClassName(classType); ClassUtil.toByteCodeClassName(classType);
} }
private ModifiedInsnNodes replaceMemberCallOps(ClassNode cn, MethodNode mn, MethodInfo mockMethod, private ModifiedInsnNodes replaceMemberCallOps(MethodNode mn, MethodInfo mockMethod, AbstractInsnNode[] instructions,
AbstractInsnNode[] instructions, String ownerClass, String ownerClass, int opcode, int start, int end) {
int opcode, int start, int end) {
LogUtil.diagnose(" Line %d, mock method \"%s\" used", getLineNum(instructions, start), LogUtil.diagnose(" Line %d, mock method \"%s\" used", getLineNum(instructions, start),
mockMethod.getMockName()); mockMethod.getMockName());
if (!mockMethod.isStatic()) {
mn.instructions.insertBefore(instructions[start], new MethodInsnNode(INVOKESTATIC, mockClassName, mn.instructions.insertBefore(instructions[start], new MethodInsnNode(INVOKESTATIC, mockClassName,
GET_TESTABLE_REF, VOID_ARGS + ClassUtil.toByteCodeClassName(mockClassName), false)); GET_TESTABLE_REF, VOID_ARGS + ClassUtil.toByteCodeClassName(mockClassName), false));
}
if (Opcodes.INVOKESTATIC == opcode || isCompanionMethod(ownerClass, opcode)) { if (Opcodes.INVOKESTATIC == opcode || isCompanionMethod(ownerClass, opcode)) {
// append a null value if it was a static invoke or in kotlin companion class // append a null value if it was a static invoke or in kotlin companion class
mn.instructions.insertBefore(instructions[start], new InsnNode(ACONST_NULL)); mn.instructions.insertBefore(instructions[start], new InsnNode(ACONST_NULL));
@ -243,25 +246,14 @@ public class SourceClassHandler extends BaseClassHandler {
mn.instructions.remove(instructions[end - 1]); mn.instructions.remove(instructions[end - 1]);
} }
} }
// method with @MockMethod will be modified as public access, so INVOKEVIRTUAL is used // method with @MockMethod will be modified as public access
mn.instructions.insertBefore(instructions[end], new MethodInsnNode(INVOKEVIRTUAL, mockClassName, int invokeOpcode = mockMethod.isStatic() ? INVOKESTATIC : INVOKEVIRTUAL;
mn.instructions.insertBefore(instructions[end], new MethodInsnNode(invokeOpcode, mockClassName,
mockMethod.getMockName(), mockMethod.getMockDesc(), false)); mockMethod.getMockName(), mockMethod.getMockDesc(), false));
mn.instructions.remove(instructions[end]); mn.instructions.remove(instructions[end]);
return new ModifiedInsnNodes(mn.instructions.toArray(), 1); return new ModifiedInsnNodes(mn.instructions.toArray(), 1);
} }
private ImmutablePair<Integer, Integer> findRangeOfInvokerInstance(AbstractInsnNode[] nodes, int start, int end) {
int accumulatedLevelChange = 0;
int edgeIndex = start;
for (int i = start; i < end; i++) {
accumulatedLevelChange -= getStackLevelChange(nodes[i]);
if (accumulatedLevelChange == 1) {
edgeIndex = i;
}
}
return ImmutablePair.of(start, edgeIndex);
}
private boolean isCompanionMethod(String ownerClass, int opcode) { private boolean isCompanionMethod(String ownerClass, int opcode) {
return Opcodes.INVOKEVIRTUAL == opcode && ClassUtil.isCompanionClassName(ownerClass); return Opcodes.INVOKEVIRTUAL == opcode && ClassUtil.isCompanionClassName(ownerClass);
} }

View File

@ -25,13 +25,18 @@ public class MethodInfo {
* parameter and return value of the mock method * parameter and return value of the mock method
*/ */
private final String mockDesc; private final String mockDesc;
/**
* whether mock method is defined as static
*/
private final boolean isStatic;
public MethodInfo(String clazz, String name, String desc, String mockName, String mockDesc) { public MethodInfo(String clazz, String name, String desc, String mockName, String mockDesc, boolean isStatic) {
this.clazz = clazz; this.clazz = clazz;
this.name = name; this.name = name;
this.desc = desc; this.desc = desc;
this.mockName = mockName; this.mockName = mockName;
this.mockDesc = mockDesc; this.mockDesc = mockDesc;
this.isStatic = isStatic;
} }
public String getClazz() { public String getClazz() {
@ -54,6 +59,10 @@ public class MethodInfo {
return mockDesc; return mockDesc;
} }
public boolean isStatic() {
return isStatic;
}
@Override @Override
public boolean equals(Object o) { public boolean equals(Object o) {
if (this == o) { return true; } if (this == o) { return true; }
@ -61,6 +70,7 @@ public class MethodInfo {
MethodInfo that = (MethodInfo)o; MethodInfo that = (MethodInfo)o;
if (isStatic != that.isStatic) { return false; }
if (!clazz.equals(that.clazz)) { return false; } if (!clazz.equals(that.clazz)) { return false; }
if (!name.equals(that.name)) { return false; } if (!name.equals(that.name)) { return false; }
if (!desc.equals(that.desc)) { return false; } if (!desc.equals(that.desc)) { return false; }
@ -75,6 +85,7 @@ public class MethodInfo {
result = 31 * result + desc.hashCode(); result = 31 * result + desc.hashCode();
result = 31 * result + mockName.hashCode(); result = 31 * result + mockName.hashCode();
result = 31 * result + mockDesc.hashCode(); result = 31 * result + mockDesc.hashCode();
result = 31 * result + (isStatic ? 1 : 0);
return result; return result;
} }
} }

View File

@ -17,6 +17,7 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName; import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName;
import static com.alibaba.testable.agent.util.MethodUtil.isStaticMethod;
import static com.alibaba.testable.core.constant.ConstPool.CONSTRUCTOR; import static com.alibaba.testable.core.constant.ConstPool.CONSTRUCTOR;
public class MockClassParser { public class MockClassParser {
@ -118,24 +119,25 @@ public class MockClassParser {
private MethodInfo getMethodInfo(MethodNode mn, AnnotationNode an, String targetMethod) { private MethodInfo getMethodInfo(MethodNode mn, AnnotationNode an, String targetMethod) {
Type targetType = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_CLASS, null, Type.class); Type targetType = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_CLASS, null, Type.class);
boolean isStatic = isStaticMethod(mn);
if (targetType == null) { if (targetType == null) {
// "targetClass" unset, use first parameter as target class type // "targetClass" unset, use first parameter as target class type
ImmutablePair<String, String> methodDescPair = extractFirstParameter(mn.desc); ImmutablePair<String, String> methodDescPair = extractFirstParameter(mn.desc);
if (methodDescPair == null) { if (methodDescPair == null) {
return null; return null;
} }
return new MethodInfo(methodDescPair.left, targetMethod, methodDescPair.right, mn.name, mn.desc); return new MethodInfo(methodDescPair.left, targetMethod, methodDescPair.right, mn.name, mn.desc, isStatic);
} else { } else {
// "targetClass" found, use it as target class type // "targetClass" found, use it as target class type
String slashSeparatedName = ClassUtil.toSlashSeparatedName(targetType.getClassName()); String slashSeparatedName = ClassUtil.toSlashSeparatedName(targetType.getClassName());
return new MethodInfo(slashSeparatedName, targetMethod, mn.desc, mn.name, return new MethodInfo(slashSeparatedName, targetMethod, mn.desc, mn.name,
ClassUtil.addParameterAtBegin(mn.desc, ClassUtil.toByteCodeClassName(slashSeparatedName))); ClassUtil.addParameterAtBegin(mn.desc, ClassUtil.toByteCodeClassName(slashSeparatedName)), isStatic);
} }
} }
private void addMockConstructor(List<MethodInfo> methodInfos, ClassNode cn, MethodNode mn) { private void addMockConstructor(List<MethodInfo> methodInfos, ClassNode cn, MethodNode mn) {
String sourceClassName = ClassUtil.getSourceClassName(cn.name); String sourceClassName = ClassUtil.getSourceClassName(cn.name);
methodInfos.add(new MethodInfo(sourceClassName, CONSTRUCTOR, mn.desc, mn.name, mn.desc)); methodInfos.add(new MethodInfo(sourceClassName, CONSTRUCTOR, mn.desc, mn.name, mn.desc, isStaticMethod(mn)));
} }
/** /**

View File

@ -0,0 +1,13 @@
package com.alibaba.testable.agent.util;
import org.objectweb.asm.tree.MethodNode;
import static org.objectweb.asm.Opcodes.ACC_STATIC;
public class MethodUtil {
public static boolean isStaticMethod(MethodNode mn) {
return (mn.access & ACC_STATIC) != 0;
}
}