From cd464a581b722482d90554abc8cbcad890dd6f54 Mon Sep 17 00:00:00 2001
From: Lukas Stadler <lukas.stadler@oracle.com>
Date: Wed, 9 Nov 2016 15:23:55 +0100
Subject: [PATCH] move handling of replacement temp vars into special/generic
 replacement nodes

---
 .../control/ReplacementDispatchNode.java      |   3 +-
 .../r/nodes/control/ReplacementNode.java      | 268 +++++++++---------
 2 files changed, 131 insertions(+), 140 deletions(-)

diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementDispatchNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementDispatchNode.java
index 689bacaeac..75a2e0de10 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementDispatchNode.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementDispatchNode.java
@@ -122,7 +122,8 @@ public final class ReplacementDispatchNode extends OperatorNode {
         }
         RSyntaxLookup variable = (RSyntaxLookup) current;
         ReadVariableNode varRead = createReplacementForVariableUsing(variable, isSuper);
-        return new ReplacementNode(getLazySourceSection(), operator, varRead, lhs.asRSyntaxNode(), rhs, calls, "*rhs*" + tempNamesStartIndex, variable.getIdentifier(), isSuper, tempNamesStartIndex);
+        return ReplacementNode.create(getLazySourceSection(), operator, varRead, lhs.asRSyntaxNode(), rhs, calls, variable.getIdentifier(), isSuper,
+                        tempNamesStartIndex);
     }
 
     private static ReadVariableNode createReplacementForVariableUsing(RSyntaxLookup var, boolean isSuper) {
diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementNode.java
index 2e3a23178b..f09807d82a 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementNode.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/ReplacementNode.java
@@ -55,67 +55,32 @@ import com.oracle.truffle.r.runtime.nodes.RSyntaxLookup;
 import com.oracle.truffle.r.runtime.nodes.RSyntaxNode;
 
 @TypeSystemReference(EmptyTypeSystemFlatLayout.class)
-final class ReplacementNode extends OperatorNode {
+abstract class ReplacementNode extends OperatorNode {
 
-    @Child private RNode target;
-    @Child private RNode rhs;
-    @Child private WriteVariableNode targetTmpWrite;
-    @Child private RemoveAndAnswerNode targetTmpRemove;
-    @Child private WriteVariableNode targetWrite;
-    @Child private ReplacementBase replacement;
+    protected final RSyntaxElement lhs;
 
-    @Child private WriteVariableNode storeRhs;
-    @Child private RemoveAndAnswerNode removeRhs;
-
-    @Child private SetVisibilityNode visibility = SetVisibilityNode.create();
-
-    private final int tempNamesStartIndex;
-    private final List<RSyntaxCall> calls;
-    private final String rhsVarName;
-    private final RSyntaxElement lhs;
-
-    ReplacementNode(SourceSection source, RSyntaxLookup operator, RNode target, RSyntaxElement lhs, RNode rhs, List<RSyntaxCall> calls, String rhsVarName, String targetVarName, boolean isSuper,
-                    int tempNamesStartIndex) {
+    ReplacementNode(SourceSection source, RSyntaxLookup operator, RSyntaxElement lhs) {
         super(source, operator);
         this.lhs = lhs;
-        this.rhs = rhs;
-        this.calls = calls;
-        this.rhsVarName = rhsVarName;
-        this.tempNamesStartIndex = tempNamesStartIndex;
-        this.target = target;
-        this.targetTmpWrite = WriteVariableNode.createAnonymous(getTargetTmpName(), null, WriteVariableNode.Mode.INVISIBLE);
-        this.targetTmpRemove = RemoveAndAnswerNode.create(getTargetTmpName());
-        this.targetWrite = WriteVariableNode.createAnonymous(targetVarName, null, WriteVariableNode.Mode.INVISIBLE, isSuper);
-        this.replacement = createReplacementNode(true);
-
-        this.storeRhs = WriteVariableNode.createAnonymous(rhsVarName, null, WriteVariableNode.Mode.INVISIBLE);
-        this.removeRhs = RemoveAndAnswerNode.create(rhsVarName);
-    }
-
-    @Override
-    public Object execute(VirtualFrame frame) {
-        Object rhsValue = rhs.execute(frame);
-        storeRhs.execute(frame, rhsValue);
-        targetTmpWrite.execute(frame, target.execute(frame));
-        replacement.execute(frame);
-        targetWrite.execute(frame, targetTmpRemove.execute(frame));
-        removeRhs.execute(frame);
-        visibility.execute(frame, false);
-        return rhsValue;
     }
 
-    private ReplacementBase createReplacementNode(boolean useSpecials) {
+    public static ReplacementNode create(SourceSection source, RSyntaxLookup operator, RNode target, RSyntaxElement lhs, RNode rhs, List<RSyntaxCall> calls,
+                    String targetVarName, boolean isSuper, int tempNamesStartIndex) {
         CompilerAsserts.neverPartOfCompilation();
         // Note: if specials are turned off in FastR, onlySpecials will never be true
-        boolean createSpecial = hasOnlySpecialCalls() && useSpecials;
-        return createSpecial ? createSpecialReplacement() : createGenericReplacement();
+        boolean createSpecial = hasOnlySpecialCalls(calls);
+        if (createSpecial) {
+            return createSpecialReplacement(source, operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex);
+        } else {
+            return createGenericReplacement(source, operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex);
+        }
     }
 
-    private String getTargetTmpName() {
+    private static String getTargetTmpName(int tempNamesStartIndex) {
         return "*tmp*" + tempNamesStartIndex;
     }
 
-    private boolean hasOnlySpecialCalls() {
+    private static boolean hasOnlySpecialCalls(List<RSyntaxCall> calls) {
         for (int i = 0; i < calls.size(); i++) {
             if (!(calls.get(i) instanceof RCallSpecialNode)) {
                 return false;
@@ -127,22 +92,41 @@ final class ReplacementNode extends OperatorNode {
     /**
      * Creates a replacement that consists only of {@link RCallSpecialNode} calls.
      */
-    private SpecialReplacementNode createSpecialReplacement() {
+    private static SpecialReplacementNode createSpecialReplacement(SourceSection source, RSyntaxLookup operator, RNode target, RSyntaxElement lhs, RNode rhs, List<RSyntaxCall> calls,
+                    String targetVarName, boolean isSuper, int tempNamesStartIndex) {
         CodeBuilderContext codeBuilderContext = new CodeBuilderContext(tempNamesStartIndex + 2);
-        RNode extractFunc = ReadVariableNode.create(getTargetTmpName());
+        RNode extractFunc = ReadVariableNode.create(getTargetTmpName(tempNamesStartIndex));
         for (int i = calls.size() - 1; i >= 1; i--) {
-            extractFunc = createSpecialFunctionQuery(extractFunc.asRSyntaxNode(), calls.get(i), codeBuilderContext);
+            extractFunc = createSpecialFunctionQuery(calls.get(i), extractFunc.asRSyntaxNode(), codeBuilderContext);
+            assert extractFunc instanceof RCallSpecialNode;
         }
-        RNode updateFunc = createFunctionUpdate(getLazySourceSection(), extractFunc.asRSyntaxNode(), ReadVariableNode.create(rhsVarName), calls.get(0), codeBuilderContext);
+        RNode updateFunc = createFunctionUpdate(source, extractFunc.asRSyntaxNode(), ReadVariableNode.create("*rhs*" + tempNamesStartIndex), calls.get(0), codeBuilderContext);
         assert updateFunc instanceof RCallSpecialNode : "should be only specials";
-        return new SpecialReplacementNode((RCallSpecialNode) updateFunc);
+        return new SpecialReplacementNode((RCallSpecialNode) updateFunc, source, operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex);
+    }
+
+    /**
+     * Creates a call that looks like {@code fun} but has the first argument replaced with
+     * {@code newLhs}.
+     */
+    private static RNode createSpecialFunctionQuery(RSyntaxCall fun, RSyntaxNode newFirstArg, CodeBuilderContext codeBuilderContext) {
+        RCodeBuilder<RSyntaxNode> builder = RContext.getASTBuilder();
+        RSyntaxElement[] arguments = fun.getSyntaxArguments();
+
+        RSyntaxNode[] argNodes = new RSyntaxNode[arguments.length];
+        for (int i = 0; i < arguments.length; i++) {
+            argNodes[i] = i == 0 ? newFirstArg : builder.process(arguments[i], codeBuilderContext);
+        }
+
+        return RCallSpecialNode.createCallInReplace(fun.getLazySourceSection(), builder.process(fun.getSyntaxLHS(), codeBuilderContext).asRNode(), fun.getSyntaxSignature(), argNodes).asRNode();
     }
 
     /**
      * When there are more than two function calls in LHS, then we save some function calls by
      * saving the intermediate results into temporary variables and reusing them.
      */
-    private GenericReplacementNode createGenericReplacement() {
+    private static GenericReplacementNode createGenericReplacement(SourceSection source, RSyntaxLookup operator, RNode target, RSyntaxElement lhs, RNode rhs, List<RSyntaxCall> calls,
+                    String targetVarName, boolean isSuper, int tempNamesStartIndex) {
         List<RNode> instructions = new ArrayList<>();
         CodeBuilderContext codeBuilderContext = new CodeBuilderContext(tempNamesStartIndex + calls.size() + 1);
 
@@ -154,8 +138,8 @@ final class ReplacementNode extends OperatorNode {
          * 'x' in our example (the assignment from 'x' is not done in this loop).
          */
         for (int i = calls.size() - 1, tmpIndex = 0; i >= 1; i--, tmpIndex++) {
-            ReadVariableNode newLhs = ReadVariableNode.create("*tmp*" + (tempNamesStartIndex + tmpIndex));
-            RNode update = createSpecialFunctionQuery(newLhs, calls.get(i), codeBuilderContext);
+            ReadVariableNode newFirstArg = ReadVariableNode.create("*tmp*" + (tempNamesStartIndex + tmpIndex));
+            RNode update = createSpecialFunctionQuery(calls.get(i), newFirstArg, codeBuilderContext);
             instructions.add(WriteVariableNode.createAnonymous("*tmp*" + (tempNamesStartIndex + tmpIndex + 1), update, WriteVariableNode.Mode.INVISIBLE));
         }
         /*
@@ -164,32 +148,16 @@ final class ReplacementNode extends OperatorNode {
          */
         for (int i = 0; i < calls.size(); i++) {
             int tmpIndex = tempNamesStartIndex + calls.size() - i - 1;
-            String tmprName = i == 0 ? rhsVarName : "*tmpr*" + (tempNamesStartIndex + i - 1);
-            RNode update = createFunctionUpdate(getLazySourceSection(), ReadVariableNode.create("*tmp*" + tmpIndex), ReadVariableNode.create(tmprName), calls.get(i), codeBuilderContext);
+            String tmprName = i == 0 ? ("*rhs*" + tempNamesStartIndex) : ("*tmpr*" + (tempNamesStartIndex + i - 1));
+            RNode update = createFunctionUpdate(source, ReadVariableNode.create("*tmp*" + tmpIndex), ReadVariableNode.create(tmprName), calls.get(i), codeBuilderContext);
             if (i < calls.size() - 1) {
                 instructions.add(WriteVariableNode.createAnonymous("*tmpr*" + (tempNamesStartIndex + i), update, WriteVariableNode.Mode.INVISIBLE));
             } else {
-                instructions.add(WriteVariableNode.createAnonymous(getTargetTmpName(), update, WriteVariableNode.Mode.REGULAR));
+                instructions.add(WriteVariableNode.createAnonymous(getTargetTmpName(tempNamesStartIndex), update, WriteVariableNode.Mode.REGULAR));
             }
         }
 
-        return new GenericReplacementNode(instructions);
-    }
-
-    /**
-     * Creates a call that looks like {@code fun} but has the first argument replaced with
-     * {@code newLhs}.
-     */
-    private static RNode createSpecialFunctionQuery(RSyntaxNode newLhs, RSyntaxCall fun, CodeBuilderContext codeBuilderContext) {
-        RCodeBuilder<RSyntaxNode> builder = RContext.getASTBuilder();
-        RSyntaxElement[] arguments = fun.getSyntaxArguments();
-
-        RSyntaxNode[] argNodes = new RSyntaxNode[arguments.length];
-        for (int i = 0; i < arguments.length; i++) {
-            argNodes[i] = i == 0 ? newLhs : builder.process(arguments[i], codeBuilderContext);
-        }
-
-        return RCallSpecialNode.createCallInReplace(fun.getLazySourceSection(), builder.process(fun.getSyntaxLHS(), codeBuilderContext).asRNode(), fun.getSyntaxSignature(), argNodes).asRNode();
+        return new GenericReplacementNode(instructions, source, operator, target, lhs, rhs, targetVarName, isSuper, tempNamesStartIndex);
     }
 
     /**
@@ -236,115 +204,137 @@ final class ReplacementNode extends OperatorNode {
 
     static RLanguage getLanguage(WriteVariableNode wvn) {
         Node parent = wvn.getParent();
-        if (parent instanceof ReplacementBase) {
-            return RDataFactory.createLanguage(((ReplacementBase) parent).getReplacementNodeParent());
+        if (parent instanceof ReplacementNode) {
+            return RDataFactory.createLanguage((ReplacementNode) parent);
         }
         return null;
     }
 
-    /**
-     * Base class for nodes implementing the actual replacement.
-     */
-    protected abstract static class ReplacementBase extends Node {
-
-        public abstract void execute(VirtualFrame frame);
-
-        protected final ReplacementNode getReplacementNodeParent() {
-            // Note: new DSL puts another node in between ReplacementBase instance and
-            // ReplacementNode, to be flexible we traverse the parents until we reach it
-            Node current = this;
-            do {
-                current = current.getParent();
-            } while (!(current instanceof ReplacementNode));
-            return (ReplacementNode) current;
-        }
-    }
-
     /**
      * Replacement that is made of only special calls, if one of the special calls falls back to
      * full version, the replacement also falls back to {@link ReplacementNode}.
      */
-    private static final class SpecialReplacementNode extends ReplacementBase {
+    private static final class SpecialReplacementNode extends ReplacementNode {
+
+        @Child private RNode target;
+        @Child private RNode rhs;
 
+        @Child private WriteVariableNode targetTmpWrite;
+        @Child private RemoveAndAnswerNode targetTmpRemove;
+        @Child private WriteVariableNode targetWrite;
+
+        @Child private WriteVariableNode storeRhs;
+        @Child private RemoveAndAnswerNode removeRhs;
         @Child private RCallSpecialNode replaceCall;
+        @Child private SetVisibilityNode visibility = SetVisibilityNode.create();
+
+        private final List<RSyntaxCall> calls;
+        private final int tempNamesStartIndex;
+        private final boolean isSuper;
+        private final String targetVarName;
 
-        SpecialReplacementNode(RCallSpecialNode replaceCall) {
+        SpecialReplacementNode(RCallSpecialNode replaceCall, SourceSection source, RSyntaxLookup operator, RNode target, RSyntaxElement lhs, RNode rhs, List<RSyntaxCall> calls, String targetVarName,
+                        boolean isSuper, int tempNamesStartIndex) {
+            super(source, operator, lhs);
             this.replaceCall = replaceCall;
+            this.target = target;
+            this.rhs = rhs;
+            this.calls = calls;
+            this.targetVarName = targetVarName;
+            this.isSuper = isSuper;
+            this.tempNamesStartIndex = tempNamesStartIndex;
             this.replaceCall.setPropagateFullCallNeededException(true);
+            this.targetTmpWrite = WriteVariableNode.createAnonymous(getTargetTmpName(tempNamesStartIndex), null, WriteVariableNode.Mode.INVISIBLE);
+            this.targetTmpRemove = RemoveAndAnswerNode.create(getTargetTmpName(tempNamesStartIndex));
+            this.targetWrite = WriteVariableNode.createAnonymous(targetVarName, null, WriteVariableNode.Mode.INVISIBLE, isSuper);
+
+            this.storeRhs = WriteVariableNode.createAnonymous("*rhs*" + tempNamesStartIndex, null, WriteVariableNode.Mode.INVISIBLE);
+            this.removeRhs = RemoveAndAnswerNode.create("*rhs*" + tempNamesStartIndex);
         }
 
         @Override
-        public void execute(VirtualFrame frame) {
-// special++;
+        public Object execute(VirtualFrame frame) {
             try {
+                Object rhsValue = rhs.execute(frame);
+                storeRhs.execute(frame, rhsValue);
+                targetTmpWrite.execute(frame, target.execute(frame));
                 // Note: the very last call is the actual assignment, e.g. [[<-, if this call's
                 // argument is shared, it bails out. Moreover, if that call's argument is not
                 // shared, it could not be extracted from a shared container (list), so we should be
                 // OK with not calling any other update function and just update the value directly.
                 replaceCall.execute(frame);
+
+                targetWrite.execute(frame, targetTmpRemove.execute(frame));
+                removeRhs.execute(frame);
+                visibility.execute(frame, false);
+                return rhsValue;
             } catch (FullCallNeededException | RecursiveSpecialBailout e) {
                 CompilerDirectives.transferToInterpreterAndInvalidate();
-// String code = getReplacementNodeParent().source.getCode();
-// System.out.println("fallback to generic: " + code);
-// if (code.contains("season$previousSeasonalIndex.isOnAWeekday <-
-// previousSeasonalIndex.isOnAWeekday")) {
-// System.out.println("...");
-// }
-                replace(getReplacementNodeParent().createReplacementNode(false)).execute(frame);
+                return replace(createGenericReplacement(getLazySourceSection(), operator, target, lhs, rhs, calls, targetVarName, isSuper, tempNamesStartIndex)).execute(frame);
             }
         }
+
+        @Override
+        public RSyntaxElement[] getSyntaxArguments() {
+            return new RSyntaxElement[]{lhs, rhs.asRSyntaxNode()};
+        }
     }
 
     /**
      * Holds the sequence of nodes created for R's replacement assignment.
      */
-    private static final class GenericReplacementNode extends ReplacementBase {
+    private static final class GenericReplacementNode extends ReplacementNode {
+
+        @Child private RNode target;
+        @Child private RNode rhs;
+
+        @Child private WriteVariableNode targetTmpWrite;
+        @Child private RemoveAndAnswerNode targetTmpRemove;
+        @Child private WriteVariableNode targetWrite;
+
+        @Child private WriteVariableNode storeRhs;
+        @Child private RemoveAndAnswerNode removeRhs;
         @Children private final RNode[] updates;
+        @Child private SetVisibilityNode visibility = SetVisibilityNode.create();
 
-        GenericReplacementNode(List<RNode> updates) {
+        GenericReplacementNode(List<RNode> updates, SourceSection source, RSyntaxLookup operator, RNode target, RSyntaxElement lhs, RNode rhs, String targetVarName, boolean isSuper,
+                        int tempNamesStartIndex) {
+            super(source, operator, lhs);
+            this.target = target;
+            this.rhs = rhs;
             this.updates = updates.toArray(new RNode[updates.size()]);
+            this.targetTmpWrite = WriteVariableNode.createAnonymous(getTargetTmpName(tempNamesStartIndex), null, WriteVariableNode.Mode.INVISIBLE);
+            this.targetTmpRemove = RemoveAndAnswerNode.create(getTargetTmpName(tempNamesStartIndex));
+            this.targetWrite = WriteVariableNode.createAnonymous(targetVarName, null, WriteVariableNode.Mode.INVISIBLE, isSuper);
+
+            this.storeRhs = WriteVariableNode.createAnonymous("*rhs*" + tempNamesStartIndex, null, WriteVariableNode.Mode.INVISIBLE);
+            this.removeRhs = RemoveAndAnswerNode.create("*rhs*" + tempNamesStartIndex);
         }
 
         @Override
         @ExplodeLoop
-        public void execute(VirtualFrame frame) {
-// generic++;
+        public Object execute(VirtualFrame frame) {
+            Object rhsValue = rhs.execute(frame);
+            storeRhs.execute(frame, rhsValue);
+            targetTmpWrite.execute(frame, target.execute(frame));
             for (RNode update : updates) {
                 update.execute(frame);
             }
+
+            targetWrite.execute(frame, targetTmpRemove.execute(frame));
+            removeRhs.execute(frame);
+            visibility.execute(frame, false);
+            return rhsValue;
+        }
+
+        @Override
+        public RSyntaxElement[] getSyntaxArguments() {
+            return new RSyntaxElement[]{lhs, rhs.asRSyntaxNode()};
         }
     }
-//
-// private static long special;
-// private static long generic;
-//
-// static {
-// Thread t = new Thread() {
-// @Override
-// public void run() {
-// while (true) {
-// try {
-// Thread.sleep(1000);
-// } catch (InterruptedException e) {
-// e.printStackTrace();
-// }
-// System.out.println("generic/special: " + generic + " / " + special);
-// generic = 0;
-// special = 0;
-// }
-// }
-// };
-// t.setDaemon(true);
-// t.start();
-// }
 
     @Override
     public ArgumentsSignature getSyntaxSignature() {
         return ArgumentsSignature.empty(2);
     }
-
-    @Override
-    public RSyntaxElement[] getSyntaxArguments() {
-        return new RSyntaxElement[]{lhs, rhs.asRSyntaxNode()};
-    }
 }
-- 
GitLab