From 8e372019a58a270bf9cec44e325006da77ea8aa0 Mon Sep 17 00:00:00 2001
From: Adam Welc <adam.welc@oracle.com>
Date: Thu, 13 Aug 2015 20:35:04 -0700
Subject: [PATCH] Added a function that supports waiting for a context thread
 to finish.

---
 .../truffle/r/library/fastr/FastRContext.java | 23 ++++++++-
 .../library/fastr/src/NAMESPACE               |  1 +
 .../library/fastr/src/R/fastr.R               |  5 ++
 .../r/nodes/builtin/base/foreign/FastR.java   |  2 +
 .../oracle/truffle/r/runtime/RContext.java    | 51 +++++++++++++++++--
 5 files changed, 76 insertions(+), 6 deletions(-)

diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/fastr/FastRContext.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/fastr/FastRContext.java
index b5b3b0cf66..83a6bc602b 100644
--- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/fastr/FastRContext.java
+++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/fastr/FastRContext.java
@@ -80,6 +80,27 @@ public class FastRContext {
         }
     }
 
+    public abstract static class Join extends RExternalBuiltinNode.Arg1 {
+        @Specialization
+        protected RNull eval(RIntVector contexts) {
+            try {
+                for (int i = 0; i < contexts.getLength(); i++) {
+                    RContext context = RContext.find(contexts.getDataAt(i));
+                    if (context == null) {
+                        // already done
+                        continue;
+                    } else {
+                        context.joinThread();
+                    }
+                }
+            } catch (InterruptedException ex) {
+                throw RError.error(this, RError.Message.GENERIC, "error finishing eval thread");
+
+            }
+            return RNull.instance;
+        }
+    }
+
     public abstract static class Eval extends RExternalBuiltinNode.Arg3 {
         @Specialization
         protected RNull eval(RIntVector contexts, RAbstractStringVector exprs, byte par) {
@@ -97,7 +118,7 @@ public class FastRContext {
                         threads[i].join();
                     }
                 } catch (InterruptedException ex) {
-
+                    throw RError.error(this, RError.Message.GENERIC, "error finishing eval thread");
                 }
             } else {
                 for (int i = 0; i < contexts.getLength(); i++) {
diff --git a/com.oracle.truffle.r.native/library/fastr/src/NAMESPACE b/com.oracle.truffle.r.native/library/fastr/src/NAMESPACE
index 279c0d1878..ab0090b5af 100644
--- a/com.oracle.truffle.r.native/library/fastr/src/NAMESPACE
+++ b/com.oracle.truffle.r.native/library/fastr/src/NAMESPACE
@@ -13,6 +13,7 @@ export(fastr.createpkgsources)
 export(fastr.createpkgsource)
 export(fastr.context.create)
 export(fastr.context.spawn)
+export(fastr.context.join)
 export(fastr.context.eval)
 export(fastr.context.pareval)
 export(print.fastr_context)
diff --git a/com.oracle.truffle.r.native/library/fastr/src/R/fastr.R b/com.oracle.truffle.r.native/library/fastr/src/R/fastr.R
index bea405893f..d782c83b10 100644
--- a/com.oracle.truffle.r.native/library/fastr/src/R/fastr.R
+++ b/com.oracle.truffle.r.native/library/fastr/src/R/fastr.R
@@ -90,6 +90,11 @@ fastr.context.spawn <- function(contexts, exprs) {
 	invisible(NULL)
 }
 
+fastr.context.join <- function(contexts) {
+	.FastR(.NAME="context.join", contexts)
+	invisible(NULL)
+}
+
 fastr.context.eval <- function(contexts, exprs, par=FALSE) {
 	.FastR(.NAME="context.eval", contexts, exprs, par)
 	invisible(NULL)
diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/FastR.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/FastR.java
index 02044b5846..c8f9b77579 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/FastR.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/FastR.java
@@ -88,6 +88,8 @@ public abstract class FastR extends RBuiltinNode {
                 return FastRContextFactory.PrintNodeGen.create();
             case "context.spawn":
                 return FastRContextFactory.SpawnNodeGen.create();
+            case "context.join":
+                return FastRContextFactory.JoinNodeGen.create();
             case "context.eval":
                 return FastRContextFactory.EvalNodeGen.create();
             case "fastr.channel.create":
diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RContext.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RContext.java
index 868a882219..17182ac2d0 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RContext.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RContext.java
@@ -447,6 +447,7 @@ public final class RContext extends ExecutionContext {
         public EvalThread(RContext context, Source source) {
             super(context);
             this.source = source;
+            context.evalThread = this;
         }
 
         @Override
@@ -508,6 +509,11 @@ public final class RContext extends ExecutionContext {
      */
     private RContext sharedChild;
 
+    /**
+     * Back pointer to the evalThread.
+     */
+    private EvalThread evalThread;
+
     /**
      * Typically there is a 1-1 relationship between an {@link RContext} and the thread that is
      * performing the evaluation, so we can store the {@link RContext} in a {@link ThreadLocal}.
@@ -534,6 +540,8 @@ public final class RContext extends ExecutionContext {
 
     private static final Deque<RContext> allContexts = new ConcurrentLinkedDeque<>();
 
+    private static final Semaphore allContextsSemaphore = new Semaphore(1, true);
+
     /**
      * A (hopefully) temporary workaround to ignore the setting of {@link #resultVisible} for
      * benchmarks. Set across all contexts.
@@ -567,6 +575,20 @@ public final class RContext extends ExecutionContext {
         }
     }
 
+    /**
+     * Waits for the associated EvalThread to finish
+     *
+     * @throws InterruptedException
+     */
+    public void joinThread() throws InterruptedException {
+        EvalThread t = this.evalThread;
+        if (t == null) {
+            throw RError.error(RError.NO_NODE, RError.Message.GENERIC, "no eval thread in a given context");
+        }
+        this.evalThread = null;
+        t.join();
+    }
+
     private static final Assumption singleContextAssumption = Truffle.getRuntime().createAssumption("single RContext");
     @CompilationFinal private static RContext singleContext;
 
@@ -586,7 +608,13 @@ public final class RContext extends ExecutionContext {
         }
         this.consoleHandler = consoleHandler;
         this.interactive = consoleHandler.isInteractive();
-        allContexts.add(this);
+        try {
+            allContextsSemaphore.acquire();
+            allContexts.add(this);
+            allContextsSemaphore.release();
+        } catch (InterruptedException x) {
+            throw RError.error(RError.NO_NODE, RError.Message.GENERIC, "error destroying context");
+        }
 
         if (singleContextAssumption.isValid()) {
             if (singleContext == null) {
@@ -699,7 +727,13 @@ public final class RContext extends ExecutionContext {
             parent.sharedChild = null;
         }
         engine = null;
-        allContexts.remove(this);
+        try {
+            allContextsSemaphore.acquire();
+            allContexts.remove(this);
+            allContextsSemaphore.release();
+        } catch (InterruptedException x) {
+            throw RError.error(RError.NO_NODE, RError.Message.GENERIC, "error destroying context");
+        }
         if (parent == null) {
             threadLocalContext.set(null);
         } else {
@@ -720,10 +754,17 @@ public final class RContext extends ExecutionContext {
     }
 
     public static RContext find(int id) {
-        for (RContext context : allContexts) {
-            if (context.id == id) {
-                return context;
+        try {
+            allContextsSemaphore.acquire();
+            for (RContext context : allContexts) {
+                if (context.id == id) {
+                    allContextsSemaphore.release();
+                    return context;
+                }
             }
+            allContextsSemaphore.release();
+        } catch (InterruptedException x) {
+            throw RError.error(RError.NO_NODE, RError.Message.GENERIC, "error destroying context");
         }
         return null;
     }
-- 
GitLab