diff --git a/src/main/java/de/blazemcworld/jsscripts/Mappings.java b/src/main/java/de/blazemcworld/jsscripts/Mappings.java index 5f3b1a9..1dc8cce 100644 --- a/src/main/java/de/blazemcworld/jsscripts/Mappings.java +++ b/src/main/java/de/blazemcworld/jsscripts/Mappings.java @@ -10,12 +10,18 @@ import java.io.File; import java.io.FileReader; import java.io.InputStream; +import java.lang.reflect.Field; +import java.lang.reflect.Method; import java.net.URI; import java.net.URL; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.nio.file.Files; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; import java.util.zip.ZipEntry; import java.util.zip.ZipInputStream; @@ -149,6 +155,36 @@ private static MethodDef getMethod(ClassDef classDef, String namespace, String n return null; } + private static Set getOverloadedMethod(ClassDef classDef, String namespace, String name) { + Set out = new HashSet<>(); + for (MethodDef def : classDef.getMethods()) { + if (def.getName(namespace).equals(name)) { + out.add(def); + } + } + + try { + Class currentClass = Class.forName(classDef.getName(current()).replace('/', '.')); + Class parent = currentClass.getSuperclass(); + if (parent != null) { + ClassDef def = getClass(current(), parent.getName()); + if (def != null) { + out.addAll(getOverloadedMethod(def, namespace, name)); + } + } + for (Class iface : currentClass.getInterfaces()) { + ClassDef def = getClass(current(), iface.getName()); + if (def != null) { + out.addAll(getOverloadedMethod(def, namespace, name)); + } + } + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + + return out; + } + public static String current() { return current; } @@ -181,6 +217,20 @@ public static String remapMethod(String classNamespace, String className, String return res; } + public static Set remapOverloadedMethod(String classNamespace, String className, String from, String to, String name) { + ClassDef classDef = Mappings.getClass(classNamespace, className); + if (classDef == null) return new HashSet<>(Set.of(name)); + Set methodDefs = getOverloadedMethod(classDef, from, name); + if (methodDefs.size() == 0) return new HashSet<>(Set.of(name)); + Set out = new HashSet<>(); + for (MethodDef def : methodDefs) { + String res = def.getName(to); + if (res != null) out.add(res); + } + if (out.size() == 0) return new HashSet<>(Set.of(name)); + return out; + } + @SuppressWarnings("unused") public static String graalRemapClass(String clazz) { return remapClass("named", current(), clazz); @@ -192,7 +242,69 @@ public static String graalRemapField(Class clazz, String field) { } @SuppressWarnings("unused") - public static String graalRemapMethod(Class clazz, String method) { - return remapMethod(current(), clazz.getName(), "named", current(), method); + public static Object graalRemapOverloadedMethod(Object hostCtx, Class clazz, String searchName, boolean onlyStatic) { + try { + Set names = remapOverloadedMethod(current(), clazz.getName(), "named", current(), searchName); + + Class hostClassDescClass = Class.forName("com.oracle.truffle.host.HostClassDesc"); + Class hostContextClass = Class.forName("com.oracle.truffle.host.HostContext"); + Class overloadedMethodClass = Class.forName("com.oracle.truffle.host.HostMethodDesc$OverloadedMethod"); + Class singleMethodClass = Class.forName("com.oracle.truffle.host.HostMethodDesc$SingleMethod"); + Class hostMethodDescClass = Class.forName("com.oracle.truffle.host.HostMethodDesc"); + Class membersClass = Class.forName("com.oracle.truffle.host.HostClassDesc$Members"); + + Method forClassMethod = hostClassDescClass.getDeclaredMethod("forClass", hostContextClass, Class.class); + forClassMethod.setAccessible(true); + + Object hostClassDesc = forClassMethod.invoke(null, hostCtx, clazz); + + List overloads = new ArrayList<>(); + Method lookupMethodMethod = hostClassDescClass.getDeclaredMethod("lookupMethod", String.class, boolean.class); + Method lookupMethodBySignatureMethod = hostClassDescClass.getDeclaredMethod("lookupMethodBySignature", String.class, boolean.class); + Method lookupMethodByJNINameMethod = hostClassDescClass.getDeclaredMethod("lookupMethodByJNIName", String.class, boolean.class); + + lookupMethodMethod.setAccessible(true); + lookupMethodBySignatureMethod.setAccessible(true); + lookupMethodByJNINameMethod.setAccessible(true); + + Field overloadsField = overloadedMethodClass.getDeclaredField("overloads"); + overloadsField.setAccessible(true); + + for (String name : names) { + overloads.add(lookupMethodMethod.invoke(hostClassDesc, name, onlyStatic)); + overloads.add(lookupMethodBySignatureMethod.invoke(hostClassDesc, name, onlyStatic)); + overloads.add(lookupMethodByJNINameMethod.invoke(hostClassDesc, name, onlyStatic)); + } + + while (overloads.contains(null)) { + overloads.remove(null); + } + + List singleMethods = new ArrayList<>(); + + for (Object possibleMethod : overloads) { + if (singleMethodClass.isInstance(possibleMethod)) { + singleMethods.add(possibleMethod); + } else { + singleMethods.addAll(List.of((Object[]) overloadsField.get(possibleMethod))); + } + } + + if (singleMethods.size() == 0) { + return null; + } + + Method mergeMethod = membersClass.getDeclaredMethod("merge", hostMethodDescClass, hostMethodDescClass); + mergeMethod.setAccessible(true); + + Object out = singleMethods.get(0); + for (int i = 1; i < singleMethods.size(); i++) { + out = mergeMethod.invoke(null, out, singleMethods.get(i)); + } + + return out; + } catch (Exception e) { + throw new RuntimeException(e); + } } } diff --git a/src/main/java/de/blazemcworld/jsscripts/ScriptManager.java b/src/main/java/de/blazemcworld/jsscripts/ScriptManager.java index 94744d8..df5e08d 100644 --- a/src/main/java/de/blazemcworld/jsscripts/ScriptManager.java +++ b/src/main/java/de/blazemcworld/jsscripts/ScriptManager.java @@ -5,9 +5,7 @@ import com.google.gson.JsonParser; import net.fabricmc.fabric.api.client.event.lifecycle.v1.ClientLifecycleEvents; import org.objectweb.asm.Opcodes; -import org.objectweb.asm.tree.InsnList; -import org.objectweb.asm.tree.MethodInsnNode; -import org.objectweb.asm.tree.VarInsnNode; +import org.objectweb.asm.tree.*; import java.io.File; import java.io.IOException; @@ -239,14 +237,14 @@ private static void injectMappings() { }); Injector.transformMethod("com.oracle.truffle.host.HostInteropReflect", "findMethod", method -> { try { - InsnList instructions = new InsnList(); - - instructions.add(new VarInsnNode(Opcodes.ALOAD, 1)); - instructions.add(new VarInsnNode(Opcodes.ALOAD, 2)); - instructions.add(new MethodInsnNode(Opcodes.INVOKESTATIC, "de/blazemcworld/jsscripts/Mappings", "graalRemapMethod", "(Ljava/lang/Class;Ljava/lang/String;)Ljava/lang/String;")); - instructions.add(new VarInsnNode(Opcodes.ASTORE, 2)); - - method.instructions.insertBefore(method.instructions.getFirst(), instructions); + method.instructions.clear(); + method.instructions.add(new VarInsnNode(Opcodes.ALOAD, 0)); + method.instructions.add(new VarInsnNode(Opcodes.ALOAD, 1)); + method.instructions.add(new VarInsnNode(Opcodes.ALOAD, 2)); + method.instructions.add(new VarInsnNode(Opcodes.ILOAD, 3)); + method.instructions.add(new MethodInsnNode(Opcodes.INVOKESTATIC, "de/blazemcworld/jsscripts/Mappings", "graalRemapOverloadedMethod", "(Ljava/lang/Object;Ljava/lang/Class;Ljava/lang/String;Z)Ljava/lang/Object;")); + method.instructions.add(new TypeInsnNode(Opcodes.CHECKCAST, "com/oracle/truffle/host/HostMethodDesc")); + method.instructions.add(new InsnNode(Opcodes.ARETURN)); } catch (Exception e) { throw new RuntimeException(e); }