From 89af22ddc21d657a2d1bcac0bde1a131fe540701 Mon Sep 17 00:00:00 2001
From: stepan <stepan.sindelar@oracle.com>
Date: Thu, 23 Nov 2017 14:44:19 +0100
Subject: [PATCH] NFI Context can deal with being accessed from multiple
 threads

There is fast-path implementation that handles single thread case until
another thread appears in which case we transfer to interpreter, invalidate,
and fallback to more generic implementation.
---
 .../r/ffi/impl/nfi/TruffleNFI_Context.java    | 40 +++++++++++++++----
 .../r/runtime/context/TruffleRLanguage.java   |  7 ----
 2 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/TruffleNFI_Context.java b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/TruffleNFI_Context.java
index 98e0686ee8..0c2e4e6184 100644
--- a/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/TruffleNFI_Context.java
+++ b/com.oracle.truffle.r.ffi.impl/src/com/oracle/truffle/r/ffi/impl/nfi/TruffleNFI_Context.java
@@ -186,13 +186,14 @@ final class TruffleNFI_Context extends RFFIContext {
         }
     }
 
-    private void initCallbacksAddress() {
+    @TruffleBoundary
+    private long initCallbacksAddress() {
         // get the address of the native thread local
         try {
             Node bind = Message.createInvoke(1).createNode();
             Node executeNode = Message.createExecute(1).createNode();
             TruffleObject getCallbacksAddressFunction = (TruffleObject) ForeignAccess.sendInvoke(bind, DLL.findSymbol("Rinternals_getCallbacksAddress", null).asTruffleObject(), "bind", "(): sint64");
-            callbacksAddress = (long) ForeignAccess.sendExecute(executeNode, getCallbacksAddressFunction);
+            return (long) ForeignAccess.sendExecute(executeNode, getCallbacksAddressFunction);
         } catch (InteropException ex) {
             throw RInternalError.shouldNotReachHere(ex);
         }
@@ -222,22 +223,47 @@ final class TruffleNFI_Context extends RFFIContext {
     }
 
     private long callbacks;
+    @CompilationFinal private boolean singleThreadOnly = true;
+    @CompilationFinal private long callbacksAddressThread;
     @CompilationFinal private long callbacksAddress;
+    private long lastCallbacksAddressThread;
+    private long lastCallbacksAddress;
 
     private long pushCallbacks() {
         if (callbacksAddress == 0) {
             CompilerDirectives.transferToInterpreterAndInvalidate();
-            initCallbacksAddress();
+            callbacksAddress = initCallbacksAddress();
+            callbacksAddressThread = Thread.currentThread().getId();
         }
-        long oldCallbacks = UnsafeAdapter.UNSAFE.getLong(callbacksAddress);
         assert callbacks != 0L;
-        assert callbacksAddress != 0L;
-        UnsafeAdapter.UNSAFE.putLong(callbacksAddress, callbacks);
+        if (singleThreadOnly && callbacksAddressThread == Thread.currentThread().getId()) {
+            // Fast path for contexts used only from a single thread
+            long oldCallbacks = UnsafeAdapter.UNSAFE.getLong(callbacksAddress);
+            assert callbacksAddress != 0L;
+            UnsafeAdapter.UNSAFE.putLong(callbacksAddress, callbacks);
+            return oldCallbacks;
+        }
+        // Slow path: cache the address, but reinitialize it if the thread has changed, without
+        // transfer to interpreter this time.
+        boolean reinitialize = singleThreadOnly || lastCallbacksAddressThread != Thread.currentThread().getId();
+        if (singleThreadOnly) {
+            CompilerDirectives.transferToInterpreterAndInvalidate();
+            singleThreadOnly = false;
+        }
+        if (reinitialize) {
+            lastCallbacksAddress = initCallbacksAddress();
+            lastCallbacksAddressThread = Thread.currentThread().getId();
+        }
+        long oldCallbacks = UnsafeAdapter.UNSAFE.getLong(lastCallbacksAddress);
+        assert lastCallbacksAddress != 0L;
+        assert lastCallbacksAddressThread == Thread.currentThread().getId();
+        UnsafeAdapter.UNSAFE.putLong(lastCallbacksAddress, callbacks);
         return oldCallbacks;
     }
 
     private void popCallbacks(long beforeValue) {
-        assert UnsafeAdapter.UNSAFE.getLong(callbacksAddress) == callbacks : "invalid nesting of native calling contexts";
+        assert !singleThreadOnly || UnsafeAdapter.UNSAFE.getLong(callbacksAddress) == callbacks : "invalid nesting of native calling contexts";
+        assert singleThreadOnly || UnsafeAdapter.UNSAFE.getLong(lastCallbacksAddress) == callbacks : "invalid nesting of native calling contexts";
         UnsafeAdapter.UNSAFE.putLong(callbacksAddress, beforeValue);
     }
 
diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/context/TruffleRLanguage.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/context/TruffleRLanguage.java
index 93004fa0a9..5d31641713 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/context/TruffleRLanguage.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/context/TruffleRLanguage.java
@@ -30,11 +30,4 @@ import com.oracle.truffle.r.runtime.data.RFunction;
 public abstract class TruffleRLanguage extends TruffleLanguage<RContext> {
 
     public abstract HashMap<String, RFunction> getBuiltinFunctionCache();
-
-    @Override
-    protected boolean isThreadAccessAllowed(Thread thread, boolean singleThreaded) {
-        // FastR does not support access to a single context from multiple threads, mainly because
-        // it has to maintain thread local variables on the native side.
-        return Thread.currentThread() == thread;
-    }
 }
-- 
GitLab