From 6f4f9eda5e99b9c4c4d4f4fcd9f66702e631fe34 Mon Sep 17 00:00:00 2001
From: Lukas Stadler <lukas.stadler@oracle.com>
Date: Thu, 10 Nov 2016 15:55:39 +0100
Subject: [PATCH] replacement special calls return the rhs value with the
 exception if it was already evaluated

---
 .../nodes/builtin/base/infix/UpdateField.java |  4 +-
 .../builtin/base/infix/UpdateSubscript.java   |  2 +-
 .../control/ReplacementDispatchNode.java      | 18 +++-
 .../r/nodes/control/ReplacementNode.java      | 95 ++++++++++++++++++-
 .../r/nodes/function/RCallSpecialNode.java    | 47 +++++----
 .../r/runtime/builtins/RSpecialFactory.java   | 26 ++++-
 6 files changed, 157 insertions(+), 35 deletions(-)

diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateField.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateField.java
index 35b1f19c4a..da44b6344d 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateField.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateField.java
@@ -69,7 +69,7 @@ abstract class UpdateFieldSpecial extends SpecialsUtils.ListFieldSpecialBase {
     public RList doList(RList list, String field, Object value,
                     @Cached("getIndex(list.getNames(), field)") int index) {
         if (index == -1) {
-            throw RSpecialFactory.throwFullCallNeeded();
+            throw RSpecialFactory.throwFullCallNeeded(value);
         }
         updateCache(list, field);
         Object sharedValue = value;
@@ -89,7 +89,7 @@ abstract class UpdateFieldSpecial extends SpecialsUtils.ListFieldSpecialBase {
     @SuppressWarnings("unused")
     @Fallback
     public void doFallback(Object container, Object field, Object value) {
-        throw RSpecialFactory.throwFullCallNeeded();
+        throw RSpecialFactory.throwFullCallNeeded(value);
     }
 
     private ShareObjectNode getShareObjectNode() {
diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateSubscript.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateSubscript.java
index 4d17163f6d..5a68b3998b 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateSubscript.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/infix/UpdateSubscript.java
@@ -128,7 +128,7 @@ abstract class UpdateSubscriptSpecial extends SubscriptSpecialCommon {
     @SuppressWarnings("unused")
     @Fallback
     protected static Object setFallback(Object vector, Object index, Object value) {
-        throw RSpecialFactory.throwFullCallNeeded();
+        throw RSpecialFactory.throwFullCallNeeded(value);
     }
 }
 
diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementDispatchNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementDispatchNode.java
index caecfb41d8..656bb60251 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementDispatchNode.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementDispatchNode.java
@@ -73,15 +73,23 @@ public final class ReplacementDispatchNode extends OperatorNode {
     @Override
     public Object execute(VirtualFrame frame) {
         CompilerDirectives.transferToInterpreterAndInvalidate();
+        return create(false).execute(frame);
+    }
+
+    @Override
+    public void voidExecute(VirtualFrame frame) {
+        CompilerDirectives.transferToInterpreterAndInvalidate();
+        create(true).voidExecute(frame);
+    }
 
+    public RNode create(boolean isVoid) {
         RNode replacement;
         if (lhs instanceof RSyntaxCall) {
-            replacement = createReplacementNode();
+            replacement = createReplacementNode(isVoid);
         } else {
             replacement = new WriteVariableSyntaxNode(getLazySourceSection(), operator, lhs.asRSyntaxNode(), rhs, isSuper);
         }
-
-        return replace(replacement).execute(frame);
+        return replace(replacement);
     }
 
     @Override
@@ -94,7 +102,7 @@ public final class ReplacementDispatchNode extends OperatorNode {
         return new RSyntaxElement[]{lhs.asRSyntaxNode(), rhs.asRSyntaxNode()};
     }
 
-    private ReplacementNode createReplacementNode() {
+    private ReplacementNode createReplacementNode(boolean isVoid) {
         CompilerAsserts.neverPartOfCompilation();
 
         /*
@@ -122,7 +130,7 @@ public final class ReplacementDispatchNode extends OperatorNode {
         }
         RSyntaxLookup variable = (RSyntaxLookup) current;
         ReadVariableNode varRead = createReplacementForVariableUsing(variable, isSuper);
-        return ReplacementNode.create(getLazySourceSection(), operator, varRead, lhs.asRSyntaxNode(), rhs, calls, variable.getIdentifier(), isSuper, tempNamesStartIndex);
+        return ReplacementNode.create(getLazySourceSection(), operator, varRead, lhs.asRSyntaxNode(), rhs, calls, variable.getIdentifier(), isSuper, tempNamesStartIndex, isVoid);
     }
 
     private static ReadVariableNode createReplacementForVariableUsing(RSyntaxLookup var, boolean isSuper) {
diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementNode.java
index e97d211ff5..1e0d45dfc2 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementNode.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementNode.java
@@ -65,12 +65,16 @@ abstract class ReplacementNode extends OperatorNode {
     }
 
     public static ReplacementNode create(SourceSection source, RSyntaxLookup operator, RNode target, RSyntaxElement lhs, RNode rhs, List<RSyntaxCall> calls,
-                    String targetVarName, boolean isSuper, int tempNamesStartIndex) {
+                    String targetVarName, boolean isSuper, int tempNamesStartIndex, boolean isVoid) {
         CompilerAsserts.neverPartOfCompilation();
         // Note: if specials are turned off in FastR, onlySpecials will never be true
         boolean createSpecial = hasOnlySpecialCalls(calls);
         if (createSpecial) {
-            return new SpecialReplacementNode(source, operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex);
+            if (isVoid) {
+                return new SpecialVoidReplacementNode(source, operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex);
+            } else {
+                return new SpecialReplacementNode(source, operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex);
+            }
         } else {
             return new GenericReplacementNode(source, operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex);
         }
@@ -102,7 +106,7 @@ abstract class ReplacementNode extends OperatorNode {
             argNodes[i] = i == 0 ? newFirstArg : builder.process(arguments[i], codeBuilderContext);
         }
 
-        return RCallSpecialNode.createCallInReplace(fun.getLazySourceSection(), builder.process(fun.getSyntaxLHS(), codeBuilderContext).asRNode(), fun.getSyntaxSignature(), argNodes).asRNode();
+        return RCallSpecialNode.createCallInReplace(fun.getLazySourceSection(), builder.process(fun.getSyntaxLHS(), codeBuilderContext).asRNode(), fun.getSyntaxSignature(), argNodes, 0).asRNode();
     }
 
     /**
@@ -144,7 +148,7 @@ abstract class ReplacementNode extends OperatorNode {
             newArgs[1] = builder.lookup(oldArgs[1].getLazySourceSection(), ((RSyntaxLookup) oldArgs[1]).getIdentifier() + "<-", true);
             newSyntaxLHS = RCallSpecialNode.createCall(callLHS.getLazySourceSection(), ((RSyntaxNode) callLHS.getSyntaxLHS()).asRNode(), callLHS.getSyntaxSignature(), newArgs);
         }
-        return RCallSpecialNode.createCallInReplace(source, newSyntaxLHS.asRNode(), ArgumentsSignature.get(names), argNodes).asRNode();
+        return RCallSpecialNode.createCallInReplace(source, newSyntaxLHS.asRNode(), ArgumentsSignature.get(names), argNodes, 0, argNodes.length - 1).asRNode();
     }
 
     static RLanguage getLanguage(WriteVariableNode wvn) {
@@ -193,6 +197,12 @@ abstract class ReplacementNode extends OperatorNode {
             removeRhs.execute(frame);
         }
 
+        protected final void voidExecuteWithRhs(VirtualFrame frame, Object rhsValue) {
+            storeRhs.execute(frame, rhsValue);
+            executeReplacement(frame);
+            removeRhs.execute(frame);
+        }
+
         protected abstract void executeReplacement(VirtualFrame frame);
 
         @Override
@@ -231,7 +241,7 @@ abstract class ReplacementNode extends OperatorNode {
             RNode extractFunc = target;
             for (int i = calls.size() - 1; i >= 1; i--) {
                 extractFunc = createSpecialFunctionQuery(calls.get(i), extractFunc.asRSyntaxNode(), codeBuilderContext);
-                assert extractFunc instanceof RCallSpecialNode;
+                ((RCallSpecialNode) extractFunc).setPropagateFullCallNeededException(true);
             }
             this.replaceCall = (RCallSpecialNode) createFunctionUpdate(source, extractFunc.asRSyntaxNode(), ReadVariableNode.create("*rhs*" + tempNamesStartIndex), calls.get(0), codeBuilderContext);
             this.replaceCall.setPropagateFullCallNeededException(true);
@@ -252,6 +262,81 @@ abstract class ReplacementNode extends OperatorNode {
         }
     }
 
+    /**
+     * Replacement that is made of only special calls, if one of the special calls falls back to
+     * full version, the replacement also falls back to {@link ReplacementNode}. Additionally, this
+     * type only works if the result of the call (the rhs) is not needed.
+     */
+    private static final class SpecialVoidReplacementNode extends ReplacementNode {
+
+        @Child private RCallSpecialNode replaceCall;
+
+        private final RNode rhs;
+        private final List<RSyntaxCall> calls;
+        private final int tempNamesStartIndex;
+        private final boolean isSuper;
+        private final String targetVarName;
+        private final RNode target;
+
+        SpecialVoidReplacementNode(SourceSection source, RSyntaxLookup operator, RNode target, RSyntaxElement lhs, RNode rhs, List<RSyntaxCall> calls, String targetVarName,
+                        boolean isSuper, int tempNamesStartIndex) {
+            super(source, operator, lhs);
+            this.target = target;
+            this.rhs = rhs;
+            this.calls = calls;
+            this.targetVarName = targetVarName;
+            this.isSuper = isSuper;
+            this.tempNamesStartIndex = tempNamesStartIndex;
+
+            /*
+             * Creates a replacement that consists only of {@link RCallSpecialNode} calls.
+             */
+            CodeBuilderContext codeBuilderContext = new CodeBuilderContext(tempNamesStartIndex + 2);
+            RNode extractFunc = target;
+            for (int i = calls.size() - 1; i >= 1; i--) {
+                extractFunc = createSpecialFunctionQuery(calls.get(i), extractFunc.asRSyntaxNode(), codeBuilderContext);
+                ((RCallSpecialNode) extractFunc).setPropagateFullCallNeededException(true);
+            }
+            this.replaceCall = (RCallSpecialNode) createFunctionUpdate(source, extractFunc.asRSyntaxNode(), rhs.asRSyntaxNode(), calls.get(0), codeBuilderContext);
+            this.replaceCall.setPropagateFullCallNeededException(true);
+        }
+
+        @Override
+        public Object execute(VirtualFrame frame) {
+            CompilerDirectives.transferToInterpreterAndInvalidate();
+            GenericReplacementNode replacement = new GenericReplacementNode(getLazySourceSection(), operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex);
+            return replace(replacement).execute(frame);
+        }
+
+        @Override
+        public void voidExecute(VirtualFrame frame) {
+            try {
+                // Note: the very last call is the actual assignment, e.g. [[<-, if this call's
+                // argument is shared, it bails out. Moreover, if that call's argument is not
+                // shared, it could not be extracted from a shared container (list), so we should be
+                // OK with not calling any other update function and just update the value directly.
+                replaceCall.execute(frame);
+            } catch (FullCallNeededException | RecursiveSpecialBailout e) {
+                CompilerDirectives.transferToInterpreterAndInvalidate();
+                GenericReplacementNode replacement = replace(new GenericReplacementNode(getLazySourceSection(), operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex));
+
+                Object rhsValue = e instanceof FullCallNeededException ? ((FullCallNeededException) e).rhsValue : ((RecursiveSpecialBailout) e).rhsValue;
+                if (rhsValue == null) {
+                    // we haven't queried the rhs value yet
+                    replacement.voidExecute(frame);
+                } else {
+                    // rhs was already queried, so pass it along
+                    replacement.voidExecuteWithRhs(frame, rhsValue);
+                }
+            }
+        }
+
+        @Override
+        public RSyntaxElement[] getSyntaxArguments() {
+            return new RSyntaxElement[]{lhs, rhs.asRSyntaxNode()};
+        }
+    }
+
     /**
      * Holds the sequence of nodes created for R's replacement assignment.
      */
diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallSpecialNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallSpecialNode.java
index b9b3f54935..8259598aa4 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallSpecialNode.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallSpecialNode.java
@@ -145,22 +145,22 @@ public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode
     }
 
     /**
-     * This passes {@code true} for the isReplacement parameter and ignores the first argument,
-     * i.e., does not modify the first argument in any way before passing it to
+     * This passes {@code true} for the isReplacement parameter and ignores the specified arguments,
+     * i.e., does not modify them in any way before passing it to
      * {@link RSpecialFactory#create(ArgumentsSignature, RNode[], boolean)}.
      */
-    public static RSyntaxNode createCallInReplace(SourceSection sourceSection, RNode functionNode, ArgumentsSignature signature, RSyntaxNode[] arguments) {
-        return createCall(sourceSection, functionNode, signature, arguments, true);
+    public static RSyntaxNode createCallInReplace(SourceSection sourceSection, RNode functionNode, ArgumentsSignature signature, RSyntaxNode[] arguments, int... ignoredArguments) {
+        return createCall(sourceSection, functionNode, signature, arguments, true, ignoredArguments);
     }
 
     public static RSyntaxNode createCall(SourceSection sourceSection, RNode functionNode, ArgumentsSignature signature, RSyntaxNode[] arguments) {
         return createCall(sourceSection, functionNode, signature, arguments, false);
     }
 
-    private static RSyntaxNode createCall(SourceSection sourceSection, RNode functionNode, ArgumentsSignature signature, RSyntaxNode[] arguments, boolean inReplace) {
+    private static RSyntaxNode createCall(SourceSection sourceSection, RNode functionNode, ArgumentsSignature signature, RSyntaxNode[] arguments, boolean inReplace, int... ignoredArguments) {
         RCallSpecialNode special = null;
         if (useSpecials) {
-            special = tryCreate(sourceSection, functionNode, signature, arguments, inReplace);
+            special = tryCreate(sourceSection, functionNode, signature, arguments, inReplace, ignoredArguments);
         }
         if (special != null) {
             return special;
@@ -169,14 +169,17 @@ public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode
         }
     }
 
-    private static RCallSpecialNode tryCreate(SourceSection sourceSection, RNode functionNode, ArgumentsSignature signature, RSyntaxNode[] arguments, boolean inReplace) {
+    private static RCallSpecialNode tryCreate(SourceSection sourceSection, RNode functionNode, ArgumentsSignature signature, RSyntaxNode[] arguments, boolean inReplace, int[] ignoredArguments) {
         RSyntaxNode syntaxFunction = functionNode.asRSyntaxNode();
         if (!(syntaxFunction instanceof RSyntaxLookup)) {
             // LHS is not a simple lookup -> bail out
             return null;
         }
-        for (RSyntaxNode argument : arguments) {
-            if (!(argument instanceof RSyntaxLookup || argument instanceof RSyntaxConstant || argument instanceof RCallSpecialNode)) {
+        for (int i = 0; i < arguments.length; i++) {
+            if (contains(ignoredArguments, i)) {
+                continue;
+            }
+            if (!(arguments[i] instanceof RSyntaxLookup || arguments[i] instanceof RSyntaxConstant || arguments[i] instanceof RCallSpecialNode)) {
                 // argument is not a simple lookup or constant value or another special -> bail out
                 return null;
             }
@@ -194,10 +197,7 @@ public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode
         }
         RNode[] localArguments = new RNode[arguments.length];
         for (int i = 0; i < arguments.length; i++) {
-            if (inReplace && i == 0) {
-                if (arguments[i] instanceof RCallSpecialNode) {
-                    ((RCallSpecialNode) arguments[i]).setArgumentIndex(i);
-                }
+            if (inReplace && contains(ignoredArguments, i)) {
                 localArguments[i] = arguments[i].asRNode();
             } else {
                 if (arguments[i] instanceof RSyntaxLookup) {
@@ -222,6 +222,15 @@ public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode
         return new RCallSpecialNode(sourceSection, functionNode, expectedFunction, arguments, signature, special);
     }
 
+    private static boolean contains(int[] ignoredArguments, int index) {
+        for (int i = 0; i < ignoredArguments.length; i++) {
+            if (ignoredArguments[i] == index) {
+                return true;
+            }
+        }
+        return false;
+    }
+
     @Override
     public Object execute(VirtualFrame frame, Object function) {
         try {
@@ -232,18 +241,18 @@ public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode
             return special.execute(frame);
         } catch (RecursiveSpecialBailout bailout) {
             CompilerDirectives.transferToInterpreterAndInvalidate();
-            throwOnRecursiveSpecial();
+            throwOnRecursiveSpecial(bailout.rhsValue);
             return replace(getRCallNode(rewriteSpecialArgument(bailout))).execute(frame, function);
         } catch (RSpecialFactory.FullCallNeededException e) {
             CompilerDirectives.transferToInterpreterAndInvalidate();
-            throwOnRecursiveSpecial();
+            throwOnRecursiveSpecial(e.rhsValue);
             return replace(getRCallNode()).execute(frame, function);
         }
     }
 
-    private void throwOnRecursiveSpecial() {
+    private void throwOnRecursiveSpecial(Object rhsValue) {
         if (isRecursiveSpecial()) {
-            throw new RecursiveSpecialBailout(argumentIndex);
+            throw new RecursiveSpecialBailout(argumentIndex, rhsValue);
         }
     }
 
@@ -313,9 +322,11 @@ public final class RCallSpecialNode extends RCallBaseNode implements RSyntaxNode
     @SuppressWarnings("serial")
     public static final class RecursiveSpecialBailout extends RuntimeException {
         public final int argumentIndex;
+        public final Object rhsValue;
 
-        RecursiveSpecialBailout(int argumentIndex) {
+        RecursiveSpecialBailout(int argumentIndex, Object rhsValue) {
             this.argumentIndex = argumentIndex;
+            this.rhsValue = rhsValue;
         }
 
         @Override
diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/builtins/RSpecialFactory.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/builtins/RSpecialFactory.java
index d078aab5d1..f5646bc4ac 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/builtins/RSpecialFactory.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/builtins/RSpecialFactory.java
@@ -33,9 +33,26 @@ import com.oracle.truffle.r.runtime.nodes.RNode;
  * {@link #throwFullCallNeeded()} and the it will be replaced with call to the full blown built-in.
  */
 public interface RSpecialFactory {
+
+    /**
+     * Signals that the current special call cannot fulfill the requested action because of some
+     * restriction, and that the call should be rewritten to the generic call. This function must
+     * not be used if the rhs value of an update call was already computed - in this case, use the
+     * {@link #throwFullCallNeeded(Object)} function instead.
+     */
     static FullCallNeededException throwFullCallNeeded() {
         CompilerDirectives.transferToInterpreterAndInvalidate();
-        throw FullCallNeededException.INSTANCE;
+        throw new FullCallNeededException(null);
+    }
+
+    /**
+     * Signals that the current special call cannot fulfill the requested action because of some
+     * restriction, and that the call should be rewritten to the generic call. This function must be
+     * used when a rhs value of an update call was already computed and must not be recomputed.
+     */
+    static FullCallNeededException throwFullCallNeeded(Object rhsValue) {
+        CompilerDirectives.transferToInterpreterAndInvalidate();
+        throw new FullCallNeededException(rhsValue);
     }
 
     /**
@@ -51,10 +68,11 @@ public interface RSpecialFactory {
 
     @SuppressWarnings("serial")
     final class FullCallNeededException extends RuntimeException {
-        private static RuntimeException INSTANCE = new FullCallNeededException();
 
-        private FullCallNeededException() {
-            // singleton
+        public Object rhsValue;
+
+        private FullCallNeededException(Object rhsValue) {
+            this.rhsValue = rhsValue;
         }
 
         @Override
-- 
GitLab