From 3ff79c88dc1ace2734f6b22e242948d4fba075e8 Mon Sep 17 00:00:00 2001
From: Lukas Stadler <lukas.stadler@oracle.com>
Date: Mon, 12 Dec 2016 11:15:33 +0100
Subject: [PATCH] store handler and restart stacks in frame slots only on
 demand

---
 .../builtin/base/ConditionFunctions.java      | 41 ++++++++++++++--
 .../function/FunctionDefinitionNode.java      | 49 ++++++++++++++++---
 .../truffle/r/runtime/RErrorHandling.java     | 10 ++++
 .../r/runtime/env/frame/RFrameSlot.java       | 10 +++-
 .../truffle/r/test/ExpectedTestOutput.test    | 13 +++++
 .../library/base/TestConditionHandling.java   |  3 ++
 6 files changed, 114 insertions(+), 12 deletions(-)

diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ConditionFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ConditionFunctions.java
index 3957c5a02a..14d191a2e9 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ConditionFunctions.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ConditionFunctions.java
@@ -21,9 +21,14 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
 
 import com.oracle.truffle.api.CompilerDirectives;
 import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
+import com.oracle.truffle.api.dsl.Cached;
 import com.oracle.truffle.api.dsl.Specialization;
+import com.oracle.truffle.api.frame.FrameSlot;
+import com.oracle.truffle.api.frame.FrameSlotTypeException;
+import com.oracle.truffle.api.frame.VirtualFrame;
 import com.oracle.truffle.r.nodes.builtin.CastBuilder;
 import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
+import com.oracle.truffle.r.nodes.function.FunctionDefinitionNode;
 import com.oracle.truffle.r.nodes.function.PromiseHelperNode;
 import com.oracle.truffle.r.runtime.RError;
 import com.oracle.truffle.r.runtime.RErrorHandling;
@@ -67,15 +72,31 @@ public class ConditionFunctions {
             return getHandlerStack();
         }
 
+        protected FrameSlot createHandlerFrameSlot(VirtualFrame frame) {
+            return ((FunctionDefinitionNode) getRootNode()).getHandlerFrameSlot(frame);
+        }
+
         @Specialization
-        @TruffleBoundary
-        protected Object addCondHands(RAbstractStringVector classes, RList handlers, REnvironment parentEnv, Object target, byte calling) {
+        protected Object addCondHands(VirtualFrame frame, RAbstractStringVector classes, RList handlers, REnvironment parentEnv, Object target, byte calling,
+                        @Cached("createHandlerFrameSlot(frame)") FrameSlot handlerFrameSlot) {
             if (classes.getLength() != handlers.getLength()) {
+                CompilerDirectives.transferToInterpreter();
                 throw RError.error(this, RError.Message.BAD_HANDLER_DATA);
             }
-            return RErrorHandling.createHandlers(classes, handlers, parentEnv, target, calling);
+            try {
+                if (!frame.isObject(handlerFrameSlot) || frame.getObject(handlerFrameSlot) == null) {
+                    frame.setObject(handlerFrameSlot, RErrorHandling.getHandlerStack());
+                }
+            } catch (FrameSlotTypeException e) {
+                throw RInternalError.shouldNotReachHere();
+            }
+            return createHandlers(classes, handlers, parentEnv, target, calling);
         }
 
+        @TruffleBoundary
+        private static Object createHandlers(RAbstractStringVector classes, RList handlers, REnvironment parentEnv, Object target, byte calling) {
+            return RErrorHandling.createHandlers(classes, handlers, parentEnv, target, calling);
+        }
     }
 
     @RBuiltin(name = ".resetCondHands", visibility = OFF, kind = INTERNAL, parameterNames = {"stack"}, behavior = COMPLEX)
@@ -108,9 +129,21 @@ public class ConditionFunctions {
             restart(casts);
         }
 
+        protected FrameSlot createRestartFrameSlot(VirtualFrame frame) {
+            return ((FunctionDefinitionNode) getRootNode()).getRestartFrameSlot(frame);
+        }
+
         @Specialization
-        protected Object addRestart(RList restart) {
+        protected Object addRestart(VirtualFrame frame, RList restart,
+                        @Cached("createRestartFrameSlot(frame)") FrameSlot restartFrameSlot) {
             checkLength(restart);
+            try {
+                if (!frame.isObject(restartFrameSlot) || frame.getObject(restartFrameSlot) == null) {
+                    frame.setObject(restartFrameSlot, RErrorHandling.getRestartStack());
+                }
+            } catch (FrameSlotTypeException e) {
+                throw RInternalError.shouldNotReachHere();
+            }
             RErrorHandling.addRestart(restart);
             return RNull.instance;
         }
diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/FunctionDefinitionNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/FunctionDefinitionNode.java
index 4744ee9b83..e2019c4c78 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/FunctionDefinitionNode.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/FunctionDefinitionNode.java
@@ -25,9 +25,11 @@ package com.oracle.truffle.r.nodes.function;
 import java.util.ArrayList;
 import java.util.List;
 
+import com.oracle.truffle.api.Assumption;
 import com.oracle.truffle.api.CompilerDirectives;
 import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
 import com.oracle.truffle.api.RootCallTarget;
+import com.oracle.truffle.api.Truffle;
 import com.oracle.truffle.api.frame.FrameDescriptor;
 import com.oracle.truffle.api.frame.FrameSlot;
 import com.oracle.truffle.api.frame.FrameSlotKind;
@@ -113,6 +115,11 @@ public final class FunctionDefinitionNode extends RRootNode implements RSyntaxNo
 
     @CompilationFinal private boolean containsDispatch;
 
+    private final Assumption noHandlerStackSlot = Truffle.getRuntime().createAssumption();
+    private final Assumption noRestartStackSlot = Truffle.getRuntime().createAssumption();
+    @CompilationFinal private FrameSlot handlerStackSlot;
+    @CompilationFinal private FrameSlot restartStackSlot;
+
     /**
      * Profiling for catching {@link ReturnException}s.
      */
@@ -229,12 +236,6 @@ public final class FunctionDefinitionNode extends RRootNode implements RSyntaxNo
 
     @Override
     public Object execute(VirtualFrame frame) {
-        /*
-         * It might be possible to only record this iff a handler is installed, by using the
-         * RArguments array.
-         */
-        Object handlerStack = RErrorHandling.getHandlerStack();
-        Object restartStack = RErrorHandling.getRestartStack();
         boolean runOnExitHandlers = true;
         try {
             verifyEnclosingAssumptions(frame);
@@ -284,7 +285,21 @@ public final class FunctionDefinitionNode extends RRootNode implements RSyntaxNo
                 argPostProcess.execute(frame);
             }
             if (runOnExitHandlers) {
-                RErrorHandling.restoreStacks(handlerStack, restartStack);
+                if (!noHandlerStackSlot.isValid() && frame.isObject(handlerStackSlot)) {
+                    try {
+                        RErrorHandling.restoreHandlerStack(frame.getObject(handlerStackSlot));
+                    } catch (FrameSlotTypeException e) {
+                        throw RInternalError.shouldNotReachHere();
+                    }
+                }
+                if (!noRestartStackSlot.isValid() && frame.isObject(restartStackSlot)) {
+                    try {
+                        RErrorHandling.restoreRestartStack(frame.getObject(restartStackSlot));
+                    } catch (FrameSlotTypeException e) {
+                        throw RInternalError.shouldNotReachHere();
+                    }
+                }
+
                 if (onExitProfile.profile(onExitSlot.hasValue(frame))) {
                     if (onExitExpressionCache == null) {
                         CompilerDirectives.transferToInterpreterAndInvalidate();
@@ -455,4 +470,24 @@ public final class FunctionDefinitionNode extends RRootNode implements RSyntaxNo
     public String getSyntaxDebugName() {
         return name;
     }
+
+    public FrameSlot getRestartFrameSlot(VirtualFrame frame) {
+        if (noRestartStackSlot.isValid()) {
+            CompilerDirectives.transferToInterpreterAndInvalidate();
+            restartStackSlot = frame.getFrameDescriptor().findOrAddFrameSlot(RFrameSlot.RestartStack);
+            noRestartStackSlot.invalidate();
+        }
+        assert restartStackSlot != null;
+        return restartStackSlot;
+    }
+
+    public FrameSlot getHandlerFrameSlot(VirtualFrame frame) {
+        if (noHandlerStackSlot.isValid()) {
+            CompilerDirectives.transferToInterpreterAndInvalidate();
+            handlerStackSlot = frame.getFrameDescriptor().findOrAddFrameSlot(RFrameSlot.HandlerStack);
+            noHandlerStackSlot.invalidate();
+        }
+        assert handlerStackSlot != null;
+        return handlerStackSlot;
+    }
 }
diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RErrorHandling.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RErrorHandling.java
index af99cbf6c7..4b5b757cc6 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RErrorHandling.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RErrorHandling.java
@@ -215,6 +215,16 @@ public class RErrorHandling {
         errorHandlingState.restartStack = savedRestartStack;
     }
 
+    public static void restoreHandlerStack(Object savedHandlerStack) {
+        ContextStateImpl errorHandlingState = getRErrorHandlingState();
+        errorHandlingState.handlerStack = savedHandlerStack;
+    }
+
+    public static void restoreRestartStack(Object savedRestartStack) {
+        ContextStateImpl errorHandlingState = getRErrorHandlingState();
+        errorHandlingState.restartStack = savedRestartStack;
+    }
+
     public static Object createHandlers(RAbstractStringVector classes, RList handlers, REnvironment parentEnv, Object target, byte calling) {
         CompilerAsserts.neverPartOfCompilation();
         Object oldStack = getRestartStack();
diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/env/frame/RFrameSlot.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/env/frame/RFrameSlot.java
index 81e2536a03..e303fac170 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/env/frame/RFrameSlot.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/env/frame/RFrameSlot.java
@@ -51,5 +51,13 @@ public enum RFrameSlot {
      * each call site, the value of {@link RCaller#getVisibility()} is extracted and stored into the
      * frame slot.
      */
-    Visibility;
+    Visibility,
+    /**
+     * Used to save the handler stack in frames that modify it.
+     */
+    HandlerStack,
+    /**
+     * Used to save the restart stack in frames that modify it.
+     */
+    RestartStack
 }
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
index 4db011572e..8bb019a177 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
@@ -59266,6 +59266,11 @@ f first 1
 #{f<-function(x){UseMethod("f")};f.logical<-function(x){print("logical")};f(TRUE)}
 [1] "logical"
 
+##com.oracle.truffle.r.test.library.base.TestConditionHandling.testTryCatch#
+#{ e <- simpleError("test error"); f <- function() { tryCatch(1, finally = print("Hello")); stop(e)}; f() }
+[1] "Hello"
+Error: test error
+
 ##com.oracle.truffle.r.test.library.base.TestConditionHandling.testTryCatch#
 #{ e <- simpleError("test error"); tryCatch(stop(e), error = function(e) e, finally = print("Hello"))}
 [1] "Hello"
@@ -59276,6 +59281,14 @@ f first 1
 Error: test error
 [1] "Hello"
 
+##com.oracle.truffle.r.test.library.base.TestConditionHandling.testTryCatch#
+#{ f <- function() { tryCatch(1, error = function(e) print("Hello")); stop("fred")}; f() }
+Error in f() : fred
+
+##com.oracle.truffle.r.test.library.base.TestConditionHandling.testTryCatch#
+#{ f <- function() { tryCatch(stop("fred"), error = function(e) print("Hello"))}; f() }
+[1] "Hello"
+
 ##com.oracle.truffle.r.test.library.base.TestConditionHandling.testTryCatch#
 #{ tryCatch(1, finally = print("Hello")) }
 [1] "Hello"
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/base/TestConditionHandling.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/base/TestConditionHandling.java
index 938ca47fe0..2aff030fb2 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/base/TestConditionHandling.java
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/base/TestConditionHandling.java
@@ -32,8 +32,11 @@ public class TestConditionHandling extends TestBase {
     public void testTryCatch() {
         assertEval("{ tryCatch(1, finally = print(\"Hello\")) }");
         assertEval("{ e <- simpleError(\"test error\"); tryCatch(stop(e), finally = print(\"Hello\")) }");
+        assertEval("{ e <- simpleError(\"test error\"); f <- function() { tryCatch(1, finally = print(\"Hello\")); stop(e)}; f() }");
         assertEval(Output.IgnoreErrorContext, "{ tryCatch(stop(\"fred\"), finally = print(\"Hello\")) }");
         assertEval("{ e <- simpleError(\"test error\"); tryCatch(stop(e), error = function(e) e, finally = print(\"Hello\"))}");
         assertEval(Ignored.Unknown, "{ tryCatch(stop(\"fred\"), error = function(e) e, finally = print(\"Hello\"))}");
+        assertEval("{ f <- function() { tryCatch(1, error = function(e) print(\"Hello\")); stop(\"fred\")}; f() }");
+        assertEval("{ f <- function() { tryCatch(stop(\"fred\"), error = function(e) print(\"Hello\"))}; f() }");
     }
 }
-- 
GitLab