fit for static mock method

This commit is contained in:
金戟 2021-02-18 18:15:11 +08:00
parent 8be5550331
commit bdd99577c4
5 changed files with 98 additions and 47 deletions

View File

@ -4,6 +4,7 @@ import com.alibaba.testable.agent.constant.ConstPool;
import com.alibaba.testable.agent.tool.ImmutablePair;
import com.alibaba.testable.agent.util.AnnotationUtil;
import com.alibaba.testable.agent.util.ClassUtil;
import com.alibaba.testable.agent.util.MethodUtil;
import com.alibaba.testable.core.model.MockScope;
import org.objectweb.asm.Label;
import org.objectweb.asm.Type;
@ -41,6 +42,7 @@ public class MockClassHandler extends BaseClassWithContextHandler {
mn.access &= ~ACC_PRIVATE;
mn.access &= ~ACC_PROTECTED;
mn.access |= ACC_PUBLIC;
// below transform order is important
unfoldTargetClass(mn);
injectInvokeRecorder(mn);
injectAssociationChecker(mn);
@ -88,26 +90,57 @@ public class MockClassHandler extends BaseClassWithContextHandler {
}
}
if (targetClassName != null) {
// must get label before method description changed
ImmutablePair<LabelNode, LabelNode> labels = getStartAndEndLabel(mn);
mn.desc = ClassUtil.addParameterAtBegin(mn.desc, targetClassName);
LocalVariableNode thisRef = mn.localVariables.get(0);
mn.localVariables.add(1, new LocalVariableNode("__self", targetClassName, null,
thisRef.start, thisRef.end, 1));
for (int i = 2; i < mn.localVariables.size(); i++) {
int parameterOffset = MethodUtil.isStaticMethod(mn) ? 0 : 1;
mn.localVariables.add(parameterOffset, new LocalVariableNode("__self", targetClassName, null,
labels.left, labels.right, parameterOffset));
for (int i = parameterOffset + 1; i < mn.localVariables.size(); i++) {
mn.localVariables.get(i).index++;
}
for (AbstractInsnNode in : mn.instructions) {
if (in instanceof IincInsnNode) {
((IincInsnNode)in).var++;
} else if (in instanceof VarInsnNode && ((VarInsnNode)in).var > 0) {
} else if (in instanceof VarInsnNode && ((VarInsnNode)in).var >= parameterOffset) {
((VarInsnNode)in).var++;
} else if (in instanceof FrameNode && ((FrameNode)in).type == F_FULL) {
((FrameNode)in).local.add(1, targetClassName);
((FrameNode)in).local.add(parameterOffset, targetClassName);
}
}
mn.maxLocals++;
}
}
private ImmutablePair<LabelNode, LabelNode> getStartAndEndLabel(MethodNode mn) {
if (MethodUtil.isStaticMethod(mn)) {
LabelNode startLabel = null, endLabel = null;
for (AbstractInsnNode n = mn.instructions.getFirst(); n != null; n = n.getNext()) {
if (n instanceof LabelNode) {
startLabel = (LabelNode)n;
break;
}
}
if (ClassUtil.extractParameters(mn.desc).isEmpty()) {
// for method without parameter, should manually add a ending label
endLabel = new LabelNode(new Label());
mn.instructions.add(endLabel);
} else {
// for method with parameters, find the existing ending label
for (AbstractInsnNode n = mn.instructions.getLast(); n != null; n = n.getPrevious()) {
if (n instanceof LabelNode) {
endLabel = (LabelNode)n;
break;
}
}
}
return ImmutablePair.of(startLabel, endLabel);
} else {
LocalVariableNode thisRef = mn.localVariables.get(0);
return ImmutablePair.of(thisRef.start, thisRef.end);
}
}
private void injectAssociationChecker(MethodNode mn) {
if (isGlobalScope(mn)) {
return;
@ -218,7 +251,7 @@ public class MockClassHandler extends BaseClassWithContextHandler {
int size = types.size();
il.add(getIntInsn(size));
il.add(new TypeInsnNode(ANEWARRAY, ClassUtil.CLASS_OBJECT));
int parameterOffset = 1;
int parameterOffset = MethodUtil.isStaticMethod(mn) ? 0 : 1;
for (int i = 0; i < size; i++) {
mn.maxStack += 3;
il.add(new InsnNode(DUP));

View File

@ -2,7 +2,6 @@ package com.alibaba.testable.agent.handler;
import com.alibaba.testable.agent.model.MethodInfo;
import com.alibaba.testable.agent.model.ModifiedInsnNodes;
import com.alibaba.testable.agent.tool.ImmutablePair;
import com.alibaba.testable.agent.util.BytecodeUtil;
import com.alibaba.testable.agent.util.ClassUtil;
import com.alibaba.testable.core.util.LogUtil;
@ -49,11 +48,11 @@ public class SourceClassHandler extends BaseClassHandler {
}
}
for (MethodNode m : cn.methods) {
transformMethod(cn, m, memberInjectMethods, newOperatorInjectMethods);
transformMethod(m, memberInjectMethods, newOperatorInjectMethods);
}
}
private void transformMethod(ClassNode cn, MethodNode mn, Set<MethodInfo> memberInjectMethods,
private void transformMethod(MethodNode mn, Set<MethodInfo> memberInjectMethods,
Set<MethodInfo> newOperatorInjectMethods) {
LogUtil.diagnose(" Handling method %s", mn.name);
AbstractInsnNode[] instructions = mn.instructions.toArray();
@ -69,12 +68,12 @@ public class SourceClassHandler extends BaseClassHandler {
if (CONSTRUCTOR.equals(node.name)) {
LogUtil.verbose(" Line %d, constructing \"%s\" as \"%s\"", getLineNum(instructions, i),
node.owner, node.desc);
String newOperatorInjectMethodName = getNewOperatorInjectMethodName(newOperatorInjectMethods, node);
if (newOperatorInjectMethodName != null) {
MethodInfo newOperatorInjectMethod = getNewOperatorInjectMethod(newOperatorInjectMethods, node);
if (newOperatorInjectMethod != null) {
// it's a new operation and an inject method for it exist
int rangeStart = getConstructorStart(instructions, node.owner, i);
if (rangeStart >= 0) {
ModifiedInsnNodes modifiedInsnNodes = replaceNewOps(cn, mn, newOperatorInjectMethodName,
ModifiedInsnNodes modifiedInsnNodes = replaceNewOps(mn, newOperatorInjectMethod,
instructions, rangeStart, i);
instructions = modifiedInsnNodes.nodes;
maxStackDiff = Math.max(maxStackDiff, modifiedInsnNodes.stackDiff);
@ -89,7 +88,7 @@ public class SourceClassHandler extends BaseClassHandler {
// it's a member or static method and an inject method for it exist
int rangeStart = getMemberMethodStart(instructions, i);
if (rangeStart >= 0) {
ModifiedInsnNodes modifiedInsnNodes = replaceMemberCallOps(cn, mn, mockMethod,
ModifiedInsnNodes modifiedInsnNodes = replaceMemberCallOps(mn, mockMethod,
instructions, node.owner, node.getOpcode(), rangeStart, i);
instructions = modifiedInsnNodes.nodes;
maxStackDiff = Math.max(maxStackDiff, modifiedInsnNodes.stackDiff);
@ -124,10 +123,10 @@ public class SourceClassHandler extends BaseClassHandler {
return null;
}
private String getNewOperatorInjectMethodName(Set<MethodInfo> newOperatorInjectMethods, MethodInsnNode node) {
private MethodInfo getNewOperatorInjectMethod(Set<MethodInfo> newOperatorInjectMethods, MethodInsnNode node) {
for (MethodInfo m : newOperatorInjectMethods) {
if (m.getDesc().equals(getConstructorInjectDesc(node))) {
return m.getMockName();
return m;
}
}
return null;
@ -198,16 +197,19 @@ public class SourceClassHandler extends BaseClassHandler {
return ClassUtil.getParameterTypes(desc).size() - (ClassUtil.getReturnType(desc).equals(VOID_RES) ? 0 : 1);
}
private ModifiedInsnNodes replaceNewOps(ClassNode cn, MethodNode mn, String newOperatorInjectMethodName,
AbstractInsnNode[] instructions, int start, int end) {
LogUtil.diagnose(" Line %d, mock method \"%s\" used", getLineNum(instructions, start),
newOperatorInjectMethodName);
private ModifiedInsnNodes replaceNewOps(MethodNode mn, MethodInfo newOperatorInjectMethod,
AbstractInsnNode[] instructions, int start, int end) {
String mockMethodName = newOperatorInjectMethod.getMockName();
int invokeOpcode = newOperatorInjectMethod.isStatic() ? INVOKESTATIC : INVOKEVIRTUAL;
LogUtil.diagnose(" Line %d, mock method \"%s\" used", getLineNum(instructions, start), mockMethodName);
String classType = ((TypeInsnNode)instructions[start]).desc;
String constructorDesc = ((MethodInsnNode)instructions[end]).desc;
mn.instructions.insertBefore(instructions[start], new MethodInsnNode(INVOKESTATIC, mockClassName,
GET_TESTABLE_REF, VOID_ARGS + ClassUtil.toByteCodeClassName(mockClassName), false));
mn.instructions.insertBefore(instructions[end], new MethodInsnNode(INVOKEVIRTUAL, mockClassName,
newOperatorInjectMethodName, getConstructorInjectDesc(constructorDesc, classType), false));
if (!newOperatorInjectMethod.isStatic()) {
mn.instructions.insertBefore(instructions[start], new MethodInsnNode(INVOKESTATIC, mockClassName,
GET_TESTABLE_REF, VOID_ARGS + ClassUtil.toByteCodeClassName(mockClassName), false));
}
mn.instructions.insertBefore(instructions[end], new MethodInsnNode(invokeOpcode, mockClassName,
mockMethodName, getConstructorInjectDesc(constructorDesc, classType), false));
mn.instructions.remove(instructions[start]);
mn.instructions.remove(instructions[start + 1]);
mn.instructions.remove(instructions[end]);
@ -228,13 +230,14 @@ public class SourceClassHandler extends BaseClassHandler {
ClassUtil.toByteCodeClassName(classType);
}
private ModifiedInsnNodes replaceMemberCallOps(ClassNode cn, MethodNode mn, MethodInfo mockMethod,
AbstractInsnNode[] instructions, String ownerClass,
int opcode, int start, int end) {
private ModifiedInsnNodes replaceMemberCallOps(MethodNode mn, MethodInfo mockMethod, AbstractInsnNode[] instructions,
String ownerClass, int opcode, int start, int end) {
LogUtil.diagnose(" Line %d, mock method \"%s\" used", getLineNum(instructions, start),
mockMethod.getMockName());
mn.instructions.insertBefore(instructions[start], new MethodInsnNode(INVOKESTATIC, mockClassName,
GET_TESTABLE_REF, VOID_ARGS + ClassUtil.toByteCodeClassName(mockClassName), false));
if (!mockMethod.isStatic()) {
mn.instructions.insertBefore(instructions[start], new MethodInsnNode(INVOKESTATIC, mockClassName,
GET_TESTABLE_REF, VOID_ARGS + ClassUtil.toByteCodeClassName(mockClassName), false));
}
if (Opcodes.INVOKESTATIC == opcode || isCompanionMethod(ownerClass, opcode)) {
// append a null value if it was a static invoke or in kotlin companion class
mn.instructions.insertBefore(instructions[start], new InsnNode(ACONST_NULL));
@ -243,25 +246,14 @@ public class SourceClassHandler extends BaseClassHandler {
mn.instructions.remove(instructions[end - 1]);
}
}
// method with @MockMethod will be modified as public access, so INVOKEVIRTUAL is used
mn.instructions.insertBefore(instructions[end], new MethodInsnNode(INVOKEVIRTUAL, mockClassName,
// method with @MockMethod will be modified as public access
int invokeOpcode = mockMethod.isStatic() ? INVOKESTATIC : INVOKEVIRTUAL;
mn.instructions.insertBefore(instructions[end], new MethodInsnNode(invokeOpcode, mockClassName,
mockMethod.getMockName(), mockMethod.getMockDesc(), false));
mn.instructions.remove(instructions[end]);
return new ModifiedInsnNodes(mn.instructions.toArray(), 1);
}
private ImmutablePair<Integer, Integer> findRangeOfInvokerInstance(AbstractInsnNode[] nodes, int start, int end) {
int accumulatedLevelChange = 0;
int edgeIndex = start;
for (int i = start; i < end; i++) {
accumulatedLevelChange -= getStackLevelChange(nodes[i]);
if (accumulatedLevelChange == 1) {
edgeIndex = i;
}
}
return ImmutablePair.of(start, edgeIndex);
}
private boolean isCompanionMethod(String ownerClass, int opcode) {
return Opcodes.INVOKEVIRTUAL == opcode && ClassUtil.isCompanionClassName(ownerClass);
}

View File

@ -25,13 +25,18 @@ public class MethodInfo {
* parameter and return value of the mock method
*/
private final String mockDesc;
/**
* whether mock method is defined as static
*/
private final boolean isStatic;
public MethodInfo(String clazz, String name, String desc, String mockName, String mockDesc) {
public MethodInfo(String clazz, String name, String desc, String mockName, String mockDesc, boolean isStatic) {
this.clazz = clazz;
this.name = name;
this.desc = desc;
this.mockName = mockName;
this.mockDesc = mockDesc;
this.isStatic = isStatic;
}
public String getClazz() {
@ -54,6 +59,10 @@ public class MethodInfo {
return mockDesc;
}
public boolean isStatic() {
return isStatic;
}
@Override
public boolean equals(Object o) {
if (this == o) { return true; }
@ -61,6 +70,7 @@ public class MethodInfo {
MethodInfo that = (MethodInfo)o;
if (isStatic != that.isStatic) { return false; }
if (!clazz.equals(that.clazz)) { return false; }
if (!name.equals(that.name)) { return false; }
if (!desc.equals(that.desc)) { return false; }
@ -75,6 +85,7 @@ public class MethodInfo {
result = 31 * result + desc.hashCode();
result = 31 * result + mockName.hashCode();
result = 31 * result + mockDesc.hashCode();
result = 31 * result + (isStatic ? 1 : 0);
return result;
}
}

View File

@ -17,6 +17,7 @@ import java.util.ArrayList;
import java.util.List;
import static com.alibaba.testable.agent.util.ClassUtil.toDotSeparateFullClassName;
import static com.alibaba.testable.agent.util.MethodUtil.isStaticMethod;
import static com.alibaba.testable.core.constant.ConstPool.CONSTRUCTOR;
public class MockClassParser {
@ -118,24 +119,25 @@ public class MockClassParser {
private MethodInfo getMethodInfo(MethodNode mn, AnnotationNode an, String targetMethod) {
Type targetType = AnnotationUtil.getAnnotationParameter(an, ConstPool.FIELD_TARGET_CLASS, null, Type.class);
boolean isStatic = isStaticMethod(mn);
if (targetType == null) {
// "targetClass" unset, use first parameter as target class type
ImmutablePair<String, String> methodDescPair = extractFirstParameter(mn.desc);
if (methodDescPair == null) {
return null;
}
return new MethodInfo(methodDescPair.left, targetMethod, methodDescPair.right, mn.name, mn.desc);
return new MethodInfo(methodDescPair.left, targetMethod, methodDescPair.right, mn.name, mn.desc, isStatic);
} else {
// "targetClass" found, use it as target class type
String slashSeparatedName = ClassUtil.toSlashSeparatedName(targetType.getClassName());
return new MethodInfo(slashSeparatedName, targetMethod, mn.desc, mn.name,
ClassUtil.addParameterAtBegin(mn.desc, ClassUtil.toByteCodeClassName(slashSeparatedName)));
ClassUtil.addParameterAtBegin(mn.desc, ClassUtil.toByteCodeClassName(slashSeparatedName)), isStatic);
}
}
private void addMockConstructor(List<MethodInfo> methodInfos, ClassNode cn, MethodNode mn) {
String sourceClassName = ClassUtil.getSourceClassName(cn.name);
methodInfos.add(new MethodInfo(sourceClassName, CONSTRUCTOR, mn.desc, mn.name, mn.desc));
methodInfos.add(new MethodInfo(sourceClassName, CONSTRUCTOR, mn.desc, mn.name, mn.desc, isStaticMethod(mn)));
}
/**

View File

@ -0,0 +1,13 @@
package com.alibaba.testable.agent.util;
import org.objectweb.asm.tree.MethodNode;
import static org.objectweb.asm.Opcodes.ACC_STATIC;
public class MethodUtil {
public static boolean isStaticMethod(MethodNode mn) {
return (mn.access & ACC_STATIC) != 0;
}
}