Revert "Fixed method reference"

This commit is contained in:
Fan Lin 2021-09-28 20:49:31 +08:00 committed by GitHub
parent 8fbd459280
commit c2e45de4f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 6 additions and 578 deletions

View File

@ -1,9 +0,0 @@
package com.alibaba.demo.lambda;
/**
* @author jim
*/
@FunctionalInterface
public interface Function1Throwable<T, R> {
R apply(T t) throws Throwable;
}

View File

@ -1,218 +0,0 @@
package com.alibaba.demo.lambda;
import java.util.ArrayList;
import java.util.Collection;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;
/**
* @author jim
*/
@SuppressWarnings("unused")
public class LambdaDemo {
public void methodReference() {
consumesRun(this::run);
}
private void consumesRun(Runnable r) {
r.run();
}
private void run() {
blackHole();
}
public String methodReference0() {
return consumes0(this::function0);
}
private String consumes0(Supplier<String> function0) {
return function0.get();
}
private String function0() {
return "Hello";
}
public String methodReference1() {
return consumes1(this::function1);
}
private String consumes1(Function<Integer, String> function) {
return function.apply(1);
}
private String function1(Integer i) {
return String.valueOf(i);
}
public String methodReferenceThrows() {
return consumesThrows(this::function1Throwable);
}
private String consumesThrows(Function1Throwable<Integer, String> function) {
try {
return function.apply(1);
}catch (Throwable e) {
e.printStackTrace();
}
return null;
}
@SuppressWarnings("RedundantThrows")
private String function1Throwable(Integer i) throws Throwable{
return String.valueOf(i);
}
public String methodReference2() {
return consumes2(this::function2);
}
private String consumes2(BiFunction<Integer, Double, String> function) {
return function.apply(1, .2);
}
private String function2(Integer i, Double d) {
return i + String.valueOf(d);
}
public String staticMethodReference1() {
return consumes1(StaticMethod::function1);
}
public String staticMethodReference2() {
return consumes2(StaticMethod::function2);
}
public void lambdaRun() {
consumes(() -> System.out.println("lambdaRun"));
}
private void consumes(Runnable o) {
o.run();
}
public void methodReferenceNew() {
Object o = consumes(Object::new);
blackHole(o);
}
private <T> T consumes(Supplier<T> s) {
return s.get();
}
private void blackHole(Object... ignore) {}
public void array() {
Function<Boolean[], Boolean[]> arrayBooleanFunction = this::arrayBooleanFunction;
Function<boolean[], boolean[]> arrayBooleanFunction1 = this::arrayBoolFunction;
Function<Byte[], Byte[]> byteFunction = this::arrayByteFunction;
Function<byte[], byte[]> byteFunction1 = this::arrayByteFunction;
Function<Character[], Character[]> charFunction = this::arrayCharFunction;
Function<char[], char[]> charFunction1 = this::arrayCharFunction;
Function<Short[], Short[]> shortFunction = this::arrayShortFunction;
Function<short[], short[]> shortFunction1 = this::arrayShortFunction;
Function<int[], int[]> intFunction = this::arrayIntFunction;
Function<Integer[], Integer[]> intFunction1 = this::arrayIntegerFunction;
Function<long[], long[]> longFunction = this::arrayLongFunction;
Function<Long[], Long[]> longFunction1 = this::arrayLongFunction;
Function<Float[], Float[]> floatFunction = this::arrayFloatFunction;
Function<float[], float[]> floatFunction1 = this::arrayFloatFunction;
Function<Double[], Double[]> doubleFunction = this::arrayDoubleFunction;
Function<double[], double[]> doubleFunction1 = this::arrayDoubleFunction;
blackHole(arrayBooleanFunction, arrayBooleanFunction1,
byteFunction, byteFunction1, charFunction, charFunction1, shortFunction, shortFunction1,
intFunction, intFunction1, longFunction, longFunction1, floatFunction, floatFunction1, doubleFunction,
doubleFunction1
);
}
private int[] arrayIntFunction(int[] arg) {
return arg;
}
private Integer[] arrayIntegerFunction(Integer[] arg) {
return arg;
}
private boolean[] arrayBoolFunction(boolean[] arg) {
return arg;
}
private Boolean[] arrayBooleanFunction(Boolean[] arg) {
return arg;
}
private byte[] arrayByteFunction(byte[] arg) {
return arg;
}
private Byte[] arrayByteFunction(Byte[] arg) {
return arg;
}
private char[] arrayCharFunction(char[] arg) {
return arg;
}
private Character[] arrayCharFunction(Character[] arg) {
return arg;
}
private short[] arrayShortFunction(short[] arg) {
return arg;
}
private Short[] arrayShortFunction(Short[] arg) {
return arg;
}
private long[] arrayLongFunction(long[] arg) {
return arg;
}
private Long[] arrayLongFunction(Long[] arg) {
return arg;
}
private float[] arrayFloatFunction(float[] arg) {
return arg;
}
private Float[] arrayFloatFunction(Float[] arg) {
return arg;
}
private double[] arrayDoubleFunction(double[] arg) {
return arg;
}
private Double[] arrayDoubleFunction(Double[] arg) {
return arg;
}
public void generic() {
Function<?, ?> genericFunction = this::genericFunction;
blackHole(genericFunction);
}
public <T, R> R genericFunction(T arg) {
//noinspection unchecked
return (R)arg;
}
private void collects() {
long l = Stream.of("1", "2", "3")
.filter(v -> !"2".equals(v))
.map(Long::parseLong)
.peek(this::blackHole)
.map(v -> new ArrayList<Long>(){{add(v);}})
.flatMap(Collection::stream)
.mapToLong(Long::valueOf)
.sum();
blackHole(l);
}
}

View File

@ -1,15 +0,0 @@
package com.alibaba.demo.lambda;
/**
* @author jim
*/
public class StaticMethod {
public static String function1(Integer i) {
return "static" + i;
}
public static String function2(Integer i, Double d) {
return "static" + i + d;
}
}

View File

@ -1,89 +0,0 @@
package com.alibaba.demo.lambda;
import com.alibaba.testable.core.annotation.MockDiagnose;
import com.alibaba.testable.core.annotation.MockMethod;
import com.alibaba.testable.core.model.LogLevel;
import org.junit.jupiter.api.Test;
import static com.alibaba.testable.core.matcher.InvokeVerifier.verify;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* @author zcbbpo
*/
public class LambdaDemoTest {
private LambdaDemo lambdaDemo = new LambdaDemo();
@SuppressWarnings("unused")
@MockDiagnose(LogLevel.VERBOSE)
public static class Mock {
@MockMethod(targetClass = LambdaDemo.class, targetMethod = "run")
private void mockRun() {
}
@MockMethod(targetClass = LambdaDemo.class)
private String function0() {
return "mock_function0";
}
@MockMethod(targetClass = LambdaDemo.class)
private String function1(Integer i) {
return "mock_function1";
}
@MockMethod(targetClass = LambdaDemo.class)
private String function2(Integer i, Double d) {
return "mock_function2";
}
@SuppressWarnings("RedundantThrows")
@MockMethod(targetClass = LambdaDemo.class)
private String function1Throwable(Integer i) throws Throwable{
return "mock_function1Throwable";
}
@MockMethod(targetClass = StaticMethod.class, targetMethod = "function1")
public static String staticFunction1(Integer i) {
return "mock_staticFunction1";
}
}
@Test
public void shouldMockRun() {
lambdaDemo.methodReference();
verify("mockRun").withTimes(1);
}
@Test
public void shouldMockFunction0() {
String s = lambdaDemo.methodReference0();
assertEquals(s, "mock_function0");
}
@Test
public void shouldMockFunction1() {
String s = lambdaDemo.methodReference1();
assertEquals(s, "mock_function1");
}
@Test
public void shouldMockFunction2() {
String s = lambdaDemo.methodReference2();
assertEquals(s, "mock_function2");
}
@Test
public void shouldMockFunction1Throws() {
String s = lambdaDemo.methodReferenceThrows();
assertEquals(s, "mock_function1Throwable");
}
@Test
public void shouldMockStaticFunction1() {
String s = lambdaDemo.staticMethodReference1();
assertEquals(s, "mock_staticFunction1");
}
}

View File

@ -19,9 +19,7 @@ abstract public class BaseClassHandler implements Opcodes {
ClassNode cn = new ClassNode();
cr.accept(cn, 0);
transform(cn);
// flag 1 was auto compute max
ClassWriter cw = new ClassWriter( ClassWriter.COMPUTE_MAXS);
ClassWriter cw = new ClassWriter( 0);
cn.accept(cw);
return cw.toByteArray();
}

View File

@ -6,17 +6,13 @@ import com.alibaba.testable.agent.util.BytecodeUtil;
import com.alibaba.testable.agent.util.ClassUtil;
import com.alibaba.testable.agent.util.MethodUtil;
import com.alibaba.testable.core.util.LogUtil;
import org.objectweb.asm.Handle;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.tree.*;
import sun.invoke.util.Wrapper;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import static com.alibaba.testable.core.constant.ConstPool.CONSTRUCTOR;
@ -25,7 +21,6 @@ import static com.alibaba.testable.core.constant.ConstPool.CONSTRUCTOR;
*/
public class SourceClassHandler extends BaseClassHandler {
private AtomicInteger atomicInteger = new AtomicInteger();
private final String mockClassName;
private final List<MethodInfo> injectMethods;
private final Set<Integer> invokeOps = new HashSet<Integer>() {{
@ -56,16 +51,13 @@ public class SourceClassHandler extends BaseClassHandler {
memberInjectMethods.add(im);
}
}
resolveMethodReference(cn);
for (MethodNode m : cn.methods) {
transformMethod(m, memberInjectMethods, newOperatorInjectMethods, cn);
transformMethod(m, memberInjectMethods, newOperatorInjectMethods);
}
}
private void transformMethod(MethodNode mn, Set<MethodInfo> memberInjectMethods,
Set<MethodInfo> newOperatorInjectMethods, ClassNode cn) {
Set<MethodInfo> newOperatorInjectMethods) {
LogUtil.verbose(" Found method %s", mn.name);
if (mn.name.startsWith("$")) {
// skip methods e.g. "$jacocoInit"
@ -119,7 +111,6 @@ public class SourceClassHandler extends BaseClassHandler {
}
}
}
i++;
} while (i < instructions.length);
}
@ -336,234 +327,4 @@ public class SourceClassHandler extends BaseClassHandler {
return Opcodes.INVOKEVIRTUAL == opcode && ClassUtil.isCompanionClassName(ownerClass);
}
private void setFinalValue(Field ownerField, Object obj, Object value) throws Exception {
ownerField.setAccessible(true);
Field modifiersField = Field.class.getDeclaredField("modifiers");
modifiersField.setAccessible(true);
modifiersField.setInt(ownerField, ownerField.getModifiers() & ~Modifier.FINAL);
ownerField.set(obj, value);
}
private List<Handle> fetchInvokeDynamicHandle(MethodNode mn) {
List<Handle> handleList = new ArrayList<Handle>();
for (AbstractInsnNode instruction : mn.instructions) {
if (instruction.getOpcode() == Opcodes.INVOKEDYNAMIC) {
InvokeDynamicInsnNode invokeDynamicInsnNode = (InvokeDynamicInsnNode)instruction;
handleList.add((Handle) invokeDynamicInsnNode.bsmArgs[1]);
}
}
return handleList;
}
private void resolveMethodReference(ClassNode cn) {
List<Handle> invokeDynamicList = new ArrayList<Handle>();
for (MethodNode method : cn.methods) {
List<Handle> handleList = fetchInvokeDynamicHandle(method);
invokeDynamicList.addAll(handleList);
}
// process for method reference
for (Handle handle : invokeDynamicList) {
if (handle.getName().startsWith("lambda$")) {
continue;
}
int tag = handle.getTag();
if (tag == Opcodes.H_NEWINVOKESPECIAL) {
// lambda new method reference
continue;
}
boolean isStatic = tag == Opcodes.H_INVOKESTATIC;
String desc = handle.getDesc();
String parameters = desc.substring(desc.indexOf("(") + 1, desc.lastIndexOf(")"));
String returnType = desc.substring(desc.indexOf(")") + 1);
String[] parameterArray = parameters.split(";");
int len = parameterArray.length;
for (String s : parameterArray) {
if (s.isEmpty()) {
len--;
}
}
String[] refineParameterArray = new String[len];
int index = 0;
for (String s : parameterArray) {
if (!s.isEmpty()) {
refineParameterArray[index] = s;
index++;
}
}
String lambdaName = String.format("Lambda$_%s_%d", handle.getName(), atomicInteger.incrementAndGet());
MethodVisitor mv = cn.visitMethod(isStatic ? ACC_PUBLIC + ACC_STATIC : ACC_PUBLIC, lambdaName, desc, null, null);
mv.visitCode();
Label l0 = new Label();
mv.visitLabel(l0);
if (!isStatic) {
mv.visitVarInsn(ALOAD, 0);
}
for (int i = 0; i < refineParameterArray.length; i++) {
String arg = refineParameterArray[i];
mv.visitVarInsn(getLoadType(arg), isStatic ? i : i + 1);
}
mv.visitMethodInsn(isStatic ? INVOKESTATIC : INVOKEVIRTUAL/*INVOKESPECIAL*/, handle.getOwner(), handle.getName(), desc, false);
mv.visitInsn(getReturnType(returnType));
Label l1 = new Label();
mv.visitLabel(l1);
String localVarOwner = handle.getOwner();
if (isStatic) {
for (int i = 0; i < refineParameterArray.length; i++) {
String localVar = refineParameterArray[i];
if (!isPrimitive(localVar)) {
localVar = localVar.endsWith(";") ? localVar : localVar + ";";
}
if (localVar.isEmpty()) {
continue;
}
mv.visitLocalVariable(String.format("o%d", i), localVar, null, l0, l1, i);
}
} else {
mv.visitLocalVariable("this", "L" + localVarOwner + ";", null, l0, l1, 0);
for (int i = 0; i < refineParameterArray.length; i++) {
String localVar = refineParameterArray[i];
if (!isPrimitive(localVar) && !isPrimitiveArray(localVar)) {
localVar = localVar.endsWith(";") ? localVar : localVar + ";";
}
if (localVar.isEmpty()) {
continue;
}
mv.visitLocalVariable(String.format("o%d", i), localVar, null, l0, l1, i + 1);
}
}
// auto compute max
mv.visitMaxs(-1, -1);
mv.visitEnd();
try {
setFinalValue(handle.getClass().getDeclaredField("name"), handle, lambdaName);
if (!handle.getOwner().equals(cn.name) && isStatic) {
setFinalValue(handle.getClass().getDeclaredField("owner"), handle, cn.name);
}
} catch (Exception ignore) {
}
}
}
@SuppressWarnings("BooleanMethodIsAlwaysInverted")
private boolean isPrimitive(String type) {
if (type.endsWith(";")) {
type = type.substring(0, type.length() - 1);
}
return BasicType.basicType(type.charAt(0)).isPrimitive();
}
private boolean isPrimitiveArray(String type) {
if (!type.startsWith("[")) {
return false;
}
if (type.endsWith(";")) {
type = type.substring(0, type.length() - 1);
}
type = type.replace("[", "");
return BasicType.basicType(type.charAt(0)).isPrimitive();
}
private int getReturnType(String returnType) {
return BasicType.basicType(returnType.charAt(0)).returnInsn;
}
private int getLoadType(String arg) {
return BasicType.basicType(arg.charAt(0)).loadVarInsn;
}
/**
* copy from java.lang.invoke.LambdaForm.BasicType
*/
enum BasicType {
/**
* all reference types
*/
L_TYPE('L', Object.class, Wrapper.OBJECT, ALOAD, ARETURN),
/**
* all primitive types
*/
I_TYPE('I', int.class, Wrapper.INT, ILOAD, IRETURN),
J_TYPE('J', long.class, Wrapper.LONG, LLOAD, LRETURN),
F_TYPE('F', float.class, Wrapper.FLOAT, FLOAD, FRETURN),
D_TYPE('D', double.class, Wrapper.DOUBLE, DLOAD, DRETURN),
V_TYPE('V', void.class, Wrapper.VOID, null, RETURN),
A_TYPE('[', Object[].class, Wrapper.OBJECT, ALOAD, ARETURN);
private final char btChar;
private final Class<?> btClass;
private final Wrapper btWrapper;
private final Integer loadVarInsn;
private final Integer returnInsn;
BasicType(char btChar, Class<?> btClass, Wrapper btWrapper, Integer loadVarInsn, Integer returnInsn) {
this.btChar = btChar;
this.btClass = btClass;
this.btWrapper = btWrapper;
this.loadVarInsn = loadVarInsn;
this.returnInsn = returnInsn;
}
public char getBtChar() {
return btChar;
}
public Class<?> getBtClass() {
return btClass;
}
public Integer getLoadVarInsn() {
return loadVarInsn;
}
public Integer getReturnInsn() {
return returnInsn;
}
public Wrapper getBtWrapper() {
return btWrapper;
}
public boolean isPrimitive() {
return this != L_TYPE && this != A_TYPE;
}
static BasicType basicType(char type) {
switch (type) {
case 'L': return L_TYPE;
case 'I': return I_TYPE;
case 'J': return J_TYPE;
case 'F': return F_TYPE;
case 'D': return D_TYPE;
case 'V': return V_TYPE;
case '[': return A_TYPE;
// all subword types are represented as ints
case 'Z':
case 'B':
case 'S':
case 'C':
return I_TYPE;
default:
throw new InternalError("Unknown type char: '"+type+"'");
}
}
}
}