From 96a59edd6947d427ddbe7ae5d182f96df0710112 Mon Sep 17 00:00:00 2001
From: Florian Angerer <florian.angerer@oracle.com>
Date: Tue, 14 Nov 2017 13:26:55 +0100
Subject: [PATCH] Added support for interop message INVOKE for R calls.

---
 .../oracle/truffle/r/nodes/RASTBuilder.java   |  44 ++++-
 .../truffle/r/nodes/function/RCallNode.java   | 175 ++++++++++++++++--
 2 files changed, 202 insertions(+), 17 deletions(-)

diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/RASTBuilder.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/RASTBuilder.java
index 1b211974b3..9d3ab2ec3e 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/RASTBuilder.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/RASTBuilder.java
@@ -46,6 +46,7 @@ import com.oracle.truffle.r.nodes.function.FormalArguments;
 import com.oracle.truffle.r.nodes.function.FunctionDefinitionNode;
 import com.oracle.truffle.r.nodes.function.FunctionExpressionNode;
 import com.oracle.truffle.r.nodes.function.PostProcessArgumentsNode;
+import com.oracle.truffle.r.nodes.function.RCallNode;
 import com.oracle.truffle.r.nodes.function.RCallSpecialNode;
 import com.oracle.truffle.r.nodes.function.SaveArgumentsNode;
 import com.oracle.truffle.r.nodes.function.WrapDefaultArgumentNode;
@@ -62,6 +63,8 @@ import com.oracle.truffle.r.runtime.env.frame.FrameSlotChangeMonitor;
 import com.oracle.truffle.r.runtime.nodes.EvaluatedArgumentsVisitor;
 import com.oracle.truffle.r.runtime.nodes.RCodeBuilder;
 import com.oracle.truffle.r.runtime.nodes.RNode;
+import com.oracle.truffle.r.runtime.nodes.RSyntaxCall;
+import com.oracle.truffle.r.runtime.nodes.RSyntaxConstant;
 import com.oracle.truffle.r.runtime.nodes.RSyntaxElement;
 import com.oracle.truffle.r.runtime.nodes.RSyntaxLookup;
 import com.oracle.truffle.r.runtime.nodes.RSyntaxNode;
@@ -95,12 +98,7 @@ public final class RASTBuilder implements RCodeBuilder<RSyntaxNode> {
                 switch (symbol) {
                     case "$":
                     case "@":
-                        if (args.get(1).value instanceof RSyntaxLookup) {
-                            RSyntaxLookup lookup = (RSyntaxLookup) args.get(1).value;
-                            // FastR differs from GNUR: we only use string constants to represent
-                            // field and slot lookups, while GNUR uses symbols
-                            args.set(1, RCodeBuilder.argument(args.get(1).source, args.get(1).name, constant(lookup.getLazySourceSection(), lookup.getIdentifier())));
-                        }
+                        convertSymbol(args);
                         break;
                     case "while":
                         return new WhileNode(source, lhsLookup, args.get(0).value, args.get(1).value);
@@ -147,9 +145,39 @@ public final class RASTBuilder implements RCodeBuilder<RSyntaxNode> {
             }
         }
 
+        if (canBeForeignInvoke(lhs)) {
+            return RCallNode.createCallDeferred(source, lhs.asRNode(), createSignature(args), createArguments(args));
+        }
         return RCallSpecialNode.createCall(source, lhs.asRNode(), createSignature(args), createArguments(args));
     }
 
+    /**
+     * Tests if some syntax expression can be a call in form of {@code lhsReceiver$lhsMember(args)}.
+     */
+    private static boolean canBeForeignInvoke(RSyntaxNode expr) {
+
+        if (expr instanceof RSyntaxCall) {
+            RSyntaxCall call = (RSyntaxCall) expr;
+            RSyntaxElement lhs = call.getSyntaxLHS();
+
+            if (lhs instanceof RSyntaxLookup && "$".equals(((RSyntaxLookup) lhs).getIdentifier())) {
+                RSyntaxElement[] syntaxArguments = call.getSyntaxArguments();
+                return syntaxArguments.length == 2 && isAllowedElement(syntaxArguments[0]) && isAllowedElement(syntaxArguments[1]);
+            }
+        }
+
+        return false;
+    }
+
+    private void convertSymbol(List<Argument<RSyntaxNode>> args) {
+        if (args.get(1).value instanceof RSyntaxLookup) {
+            RSyntaxLookup lookup = (RSyntaxLookup) args.get(1).value;
+            // FastR differs from GNUR: we only use string constants to represent
+            // field and slot lookups, while GNUR uses symbols
+            args.set(1, RCodeBuilder.argument(args.get(1).source, args.get(1).name, constant(lookup.getLazySourceSection(), lookup.getIdentifier())));
+        }
+    }
+
     private static ArgumentsSignature createSignature(List<Argument<RSyntaxNode>> args) {
         String[] argumentNames = args.stream().map(arg -> arg.name).toArray(String[]::new);
         ArgumentsSignature signature = ArgumentsSignature.get(argumentNames);
@@ -280,4 +308,8 @@ public final class RASTBuilder implements RCodeBuilder<RSyntaxNode> {
         }
         return ReadVariableNode.wrap(source, functionLookup ? ReadVariableNode.createForcedFunctionLookup(symbol) : ReadVariableNode.create(symbol));
     }
+
+    private static boolean isAllowedElement(RSyntaxElement e) {
+        return e instanceof RSyntaxLookup || e instanceof RSyntaxConstant;
+    }
 }
diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallNode.java
index 99a29d7624..2020c013cb 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallNode.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallNode.java
@@ -41,8 +41,10 @@ import com.oracle.truffle.api.frame.MaterializedFrame;
 import com.oracle.truffle.api.frame.VirtualFrame;
 import com.oracle.truffle.api.interop.ArityException;
 import com.oracle.truffle.api.interop.ForeignAccess;
+import com.oracle.truffle.api.interop.KeyInfo;
 import com.oracle.truffle.api.interop.Message;
 import com.oracle.truffle.api.interop.TruffleObject;
+import com.oracle.truffle.api.interop.UnknownIdentifierException;
 import com.oracle.truffle.api.interop.UnsupportedMessageException;
 import com.oracle.truffle.api.interop.UnsupportedTypeException;
 import com.oracle.truffle.api.nodes.ExplodeLoop;
@@ -242,6 +244,13 @@ public abstract class RCallNode extends RCallBaseNode implements RSyntaxNode, RS
         return call.execute(frame, function, lookupVarArgs(frame), null, null);
     }
 
+    @Specialization
+    public Object callForeign(VirtualFrame frame, @SuppressWarnings("unused") DeferredFunctionValue function,
+                    @SuppressWarnings("unused") @Cached("function") DeferredFunctionValue cachedFunction,
+                    @Cached("createForeignInvoke(cachedFunction)") ForeignInvoke call) {
+        return call.execute(frame);
+    }
+
     protected RNode createDispatchArgument(int index) {
         return RContext.getASTBuilder().process(arguments[index]).asRNode();
     }
@@ -543,20 +552,37 @@ public abstract class RCallNode extends RCallBaseNode implements RSyntaxNode, RS
         return call.execute(frame, resultFunction, new RArgsValuesAndNames(args, argsSignature), s3Args, s3DefaulArguments);
     }
 
-    protected final class ForeignCall extends Node {
+    protected abstract class ForeignCall extends Node {
 
         @Child protected CallArgumentsNode arguments;
         @Child protected Node messageNode;
-        @CompilationFinal protected int foreignCallArgCount;
         @Child protected Foreign2R foreign2RNode;
         @Child protected R2Foreign r2ForeignNode;
+        @CompilationFinal protected int foreignCallArgCount;
 
-        public ForeignCall(CallArgumentsNode arguments) {
+        protected ForeignCall(CallArgumentsNode arguments) {
             this.arguments = arguments;
         }
 
         protected Object[] evaluateArgs(VirtualFrame frame) {
-            return explicitArgs != null ? ((RArgsValuesAndNames) explicitArgs.execute(frame)).getArguments() : arguments.evaluateFlattenObjects(frame, lookupVarArgs(frame));
+            Object[] argumentsArray = explicitArgs != null ? ((RArgsValuesAndNames) explicitArgs.execute(frame)).getArguments() : arguments.evaluateFlattenObjects(frame, lookupVarArgs(frame));
+            if (r2ForeignNode == null) {
+                r2ForeignNode = insert(R2Foreign.create());
+            }
+            for (int i = 0; i < argumentsArray.length; i++) {
+                argumentsArray[i] = r2ForeignNode.execute(argumentsArray[i]);
+            }
+            return argumentsArray;
+        }
+    }
+
+    /**
+     * Calls a foreign function using message EXECUTE.
+     */
+    protected final class ForeignExecute extends ForeignCall {
+
+        protected ForeignExecute(CallArgumentsNode arguments) {
+            super(arguments);
         }
 
         protected Object execute(VirtualFrame frame, TruffleObject function) {
@@ -566,13 +592,9 @@ public abstract class RCallNode extends RCallBaseNode implements RSyntaxNode, RS
                 messageNode = insert(Message.createExecute(argumentsArray.length).createNode());
                 foreignCallArgCount = argumentsArray.length;
                 foreign2RNode = insert(Foreign2R.create());
-                r2ForeignNode = insert(R2Foreign.create());
             }
 
             try {
-                for (int i = 0; i < argumentsArray.length; i++) {
-                    argumentsArray[i] = r2ForeignNode.execute(argumentsArray[i]);
-                }
                 Object result = ForeignAccess.sendExecute(messageNode, function, argumentsArray);
                 return foreign2RNode.execute(result);
             } catch (ArityException | UnsupportedMessageException | UnsupportedTypeException e) {
@@ -583,8 +605,59 @@ public abstract class RCallNode extends RCallBaseNode implements RSyntaxNode, RS
         }
     }
 
-    protected ForeignCall createForeignCall() {
-        return new ForeignCall(createArguments(null, true, true));
+    /**
+     * Calls a foreign function using message INVOKE.
+     */
+    protected final class ForeignInvoke extends ForeignCall {
+
+        @Child ForeignExecute fallbackForeignExecute;
+        @CompilationFinal private boolean invoke = true;
+        private final TruffleObject receiver;
+        private final String member;
+
+        protected ForeignInvoke(TruffleObject lhsReceiver, String lhsMember, CallArgumentsNode arguments) {
+            super(arguments);
+            this.receiver = lhsReceiver;
+            this.member = lhsMember;
+        }
+
+        protected Object execute(VirtualFrame frame) {
+            Object[] argumentsArray = evaluateArgs(frame);
+            if (messageNode == null || foreignCallArgCount != argumentsArray.length) {
+                messageNode = invoke ? insert(Message.createInvoke(argumentsArray.length).createNode()) : insert(Message.createExecute(argumentsArray.length).createNode());
+                foreignCallArgCount = argumentsArray.length;
+                foreign2RNode = insert(Foreign2R.create());
+            }
+
+            try {
+                try {
+                    Object result = ForeignAccess.sendInvoke(messageNode, receiver, member, argumentsArray);
+                    return foreign2RNode.execute(result);
+                } catch (UnknownIdentifierException e) {
+                    if (invoke) {
+                        messageNode = insert(Message.createExecute(argumentsArray.length).createNode());
+                        invoke = false;
+                    }
+                    Object result = ForeignAccess.sendExecute(messageNode, receiver, argumentsArray);
+                    return foreign2RNode.execute(result);
+                }
+            } catch (ArityException | UnsupportedMessageException | UnsupportedTypeException e) {
+                CompilerDirectives.transferToInterpreter();
+                RInternalError.reportError(e);
+                throw RError.interopError(RError.findParentRBase(this), e, receiver);
+            }
+        }
+    }
+
+    protected ForeignExecute createForeignCall() {
+        return new ForeignExecute(createArguments(null, true, true));
+    }
+
+    /**
+     * Creates a foreign invoke node for a call of structure {@code lhsReceiver$lhsMember(args)}.
+     */
+    protected ForeignInvoke createForeignInvoke(DeferredFunctionValue df) {
+        return new ForeignInvoke(df.getLHSReceiver(), df.getLHSMember(), createArguments(null, true, true));
     }
 
     protected static boolean isForeignObject(Object value) {
@@ -593,7 +666,7 @@ public abstract class RCallNode extends RCallBaseNode implements RSyntaxNode, RS
 
     @Specialization(guards = "isForeignObject(function)")
     public Object call(VirtualFrame frame, TruffleObject function,
-                    @Cached("createForeignCall()") ForeignCall foreignCall) {
+                    @Cached("createForeignCall()") ForeignExecute foreignCall) {
         return foreignCall.execute(frame, function);
     }
 
@@ -637,6 +710,14 @@ public abstract class RCallNode extends RCallBaseNode implements RSyntaxNode, RS
         return RCallNodeGen.create(src, arguments, signature, function);
     }
 
+    /**
+     * The standard way to create a call to {@code function} with given arguments. If
+     * {@code src == RSyntaxNode.EAGER_DEPARSE} we force a deparse.
+     */
+    public static RCallNode createCallDeferred(SourceSection src, RNode function, ArgumentsSignature signature, RSyntaxNode... arguments) {
+        return RCallNodeGen.create(src, arguments, signature, new DeferredFunctionNode(function));
+    }
+
     /**
      * Creates a call that reads its explicit arguments from the frame under given identifier. This
      * allows to invoke a function with argument(s) supplied by hand. Consider using
@@ -1052,4 +1133,76 @@ public abstract class RCallNode extends RCallBaseNode implements RSyntaxNode, RS
     public RSyntaxElement[] getSyntaxArguments() {
         return arguments == null ? new RSyntaxElement[]{RSyntaxLookup.createDummyLookup(RSyntaxNode.LAZY_DEPARSE, "...", false)} : arguments;
     }
+
+    /**
+     * Represents the LHS of a possible foreign member call like
+     * {@code lhsReceiver$lhsMember(args)}, i.e., {@code lhsReceiver$lhsMember}.
+     */
+    protected static class DeferredFunctionValue {
+        private final TruffleObject lhsReceiver;
+        private final String lhsMember;
+
+        protected DeferredFunctionValue(TruffleObject lhsReceiver, String lhsMember) {
+            this.lhsReceiver = lhsReceiver;
+            this.lhsMember = lhsMember;
+        }
+
+        public String getLHSMember() {
+            return lhsMember;
+        }
+
+        public TruffleObject getLHSReceiver() {
+            return lhsReceiver;
+        }
+
+    }
+
+    private static class DeferredFunctionNode extends RNode {
+
+        @Child private RNode function;
+        @Child private RNode lhsReceiver;
+        @Child private RNode lhsMember;
+        @Child private Node keyInfoNode;
+
+        private final ValueProfile receiverClassProfile = ValueProfile.createClassProfile();
+        private final ValueProfile memberClassProfile = ValueProfile.createClassProfile();
+
+        public RNode getLHSMember(RNode n) {
+            return (RNode) ((RSyntaxCall) n).getSyntaxArguments()[1];
+        }
+
+        public RNode getLHSReceiver(RNode n) {
+            return (RNode) ((RSyntaxCall) n).getSyntaxArguments()[0];
+        }
+
+        protected DeferredFunctionNode(RNode function) {
+            this.lhsReceiver = getLHSReceiver(function);
+            this.lhsMember = getLHSMember(function);
+            this.function = function;
+        }
+
+        @Override
+        public Object execute(VirtualFrame frame) {
+            Object lhsReceiverObj = lhsReceiver.execute(frame);
+            if (isForeignObject(receiverClassProfile.profile(lhsReceiverObj))) {
+                Object lhsMemberObj = memberClassProfile.profile(lhsMember.execute(frame));
+                if (lhsMemberObj instanceof String) {
+                    if (keyInfoNode == null) {
+                        keyInfoNode = insert(Message.KEY_INFO.createNode());
+                    }
+                    int keyInfo = ForeignAccess.sendKeyInfo(keyInfoNode, (TruffleObject) lhsReceiverObj, (String) lhsMemberObj);
+                    if (KeyInfo.isInvocable(keyInfo)) {
+                        return new DeferredFunctionValue((TruffleObject) lhsReceiverObj, (String) lhsMemberObj);
+                    }
+                }
+            }
+            return function.execute(frame);
+        }
+
+        @Override
+        protected RSyntaxNode getRSyntaxNode() {
+            return function.asRSyntaxNode();
+        }
+
+    }
 }
-- 
GitLab