From ee3af28ca5fe51185984bb3f1d5413f3bdadd7eb Mon Sep 17 00:00:00 2001
From: Mick Jordan <mick.jordan@oracle.com>
Date: Sun, 7 Jun 2015 15:24:33 -0700
Subject: [PATCH] DLL thread safety

---
 .../com/oracle/truffle/r/runtime/ffi/DLL.java | 96 ++++++++++++-------
 1 file changed, 64 insertions(+), 32 deletions(-)

diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/DLL.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/DLL.java
index 64b4e829bb..f687cd5402 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/DLL.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/DLL.java
@@ -33,7 +33,7 @@ import com.oracle.truffle.r.runtime.data.*;
  *
  * In general, unloading a DLL may not be possible, so the set of DLLs have to be considered VM
  * wide, in the sense of multiple {@link RContext}s. TODO what about mutable state in native code?
- * So in the case of multiple {@link RContext}s package shared libraries may be registered muliple
+ * So in the case of multiple {@link RContext}s package shared libraries may be registered multiple
  * times and we must take care not to duplicate them in the meta-data here ({@link #list}).
  *
  * logic derived from Rdynload.c
@@ -218,12 +218,12 @@ public class DLL {
         }
     }
 
-    private static final Semaphore available = new Semaphore(1, false);
+    private static final Semaphore listCritical = new Semaphore(1, false);
 
     public static DLLInfo load(String path, boolean local, boolean now) throws DLLException {
         String absPath = Utils.tildeExpand(path);
         try {
-            available.acquire();
+            listCritical.acquire();
             for (DLLInfo dllInfo : list) {
                 if (dllInfo.path.equals(absPath)) {
                     // already loaded
@@ -247,11 +247,12 @@ public class DLL {
         } catch (InterruptedException ex) {
             throw RInternalError.shouldNotReachHere();
         } finally {
-            available.release();
+            listCritical.release();
         }
     }
 
     private static final String R_INIT_PREFIX = "R_init_";
+    private static final Semaphore initCritical = new Semaphore(1, false);
 
     public static DLLInfo loadPackageDLL(String path, boolean local, boolean now) throws DLLException {
         DLLInfo dllInfo = load(path, local, now);
@@ -259,9 +260,16 @@ public class DLL {
         DLL.SymbolInfo symbolInfo = DLL.findSymbolInDLL(R_INIT_PREFIX + dllInfo.name, dllInfo);
         if (symbolInfo != null) {
             try {
-                RFFIFactory.getRFFI().getCallRFFI().invokeVoidCall(symbolInfo, new Object[]{dllInfo});
-            } catch (Throwable ex) {
-                throw new DLLException(RError.Message.DLL_RINIT_ERROR);
+                initCritical.acquire();
+                try {
+                    RFFIFactory.getRFFI().getCallRFFI().invokeVoidCall(symbolInfo, new Object[]{dllInfo});
+                } catch (Throwable ex) {
+                    throw new DLLException(RError.Message.DLL_RINIT_ERROR);
+                }
+            } catch (InterruptedException ex) {
+                throw RInternalError.shouldNotReachHere();
+            } finally {
+                initCritical.release();
             }
         }
         return dllInfo;
@@ -269,14 +277,21 @@ public class DLL {
 
     public static void unload(String path) throws DLLException {
         String absPath = Utils.tildeExpand(path);
-        for (DLLInfo info : list) {
-            if (info.path.equals(absPath)) {
-                int rc = RFFIFactory.getRFFI().getBaseRFFI().dlclose(info.handle);
-                if (rc != 0) {
-                    throw new DLLException(RError.Message.DLL_LOAD_ERROR, path, "");
+        try {
+            listCritical.acquire();
+            for (DLLInfo info : list) {
+                if (info.path.equals(absPath)) {
+                    int rc = RFFIFactory.getRFFI().getBaseRFFI().dlclose(info.handle);
+                    if (rc != 0) {
+                        throw new DLLException(RError.Message.DLL_LOAD_ERROR, path, "");
+                    }
+                    return;
                 }
-                return;
             }
+        } catch (InterruptedException ex) {
+            throw RInternalError.shouldNotReachHere();
+        } finally {
+            listCritical.release();
         }
         throw new DLLException(RError.Message.DLL_NOT_LOADED, path);
     }
@@ -299,13 +314,20 @@ public class DLL {
      */
     public static SymbolInfo findSymbolInfo(String symbol, String libName) {
         SymbolInfo symbolInfo = null;
-        for (DLLInfo dllInfo : list) {
-            if (libName == null || libName.length() == 0 || dllInfo.name.equals(libName)) {
-                symbolInfo = findSymbolInDLL(symbol, dllInfo);
-                if (symbolInfo != null) {
-                    break;
+        try {
+            listCritical.acquire();
+            for (DLLInfo dllInfo : list) {
+                if (libName == null || libName.length() == 0 || dllInfo.name.equals(libName)) {
+                    symbolInfo = findSymbolInDLL(symbol, dllInfo);
+                    if (symbolInfo != null) {
+                        break;
+                    }
                 }
             }
+        } catch (InterruptedException ex) {
+            throw RInternalError.shouldNotReachHere();
+        } finally {
+            listCritical.release();
         }
         return symbolInfo;
     }
@@ -347,29 +369,39 @@ public class DLL {
      * functions.
      */
     public static SymbolInfo findRegisteredSymbolinInDLL(String symbol, String libName) {
-        for (DLLInfo dllInfo : list) {
-            if (libName == null || libName.length() == 0 || dllInfo.name.equals(libName)) {
-                if (dllInfo.forceSymbols) {
-                    continue;
-                }
-                for (NativeSymbolType nst : NativeSymbolType.values()) {
-                    DotSymbol[] dotSymbols = dllInfo.getNativeSymbols(nst);
-                    if (dotSymbols == null) {
+        try {
+            listCritical.acquire();
+            for (DLLInfo dllInfo : list) {
+                if (libName == null || libName.length() == 0 || dllInfo.name.equals(libName)) {
+                    if (dllInfo.forceSymbols) {
                         continue;
                     }
-                    for (DotSymbol dotSymbol : dotSymbols) {
-                        if (dotSymbol.name.equals(symbol)) {
-                            return new SymbolInfo(dllInfo, symbol, dotSymbol.fun);
+                    for (NativeSymbolType nst : NativeSymbolType.values()) {
+                        DotSymbol[] dotSymbols = dllInfo.getNativeSymbols(nst);
+                        if (dotSymbols == null) {
+                            continue;
+                        }
+                        for (DotSymbol dotSymbol : dotSymbols) {
+                            if (dotSymbol.name.equals(symbol)) {
+                                return new SymbolInfo(dllInfo, symbol, dotSymbol.fun);
+                            }
                         }
                     }
-                }
 
+                }
             }
+        } catch (InterruptedException ex) {
+            throw RInternalError.shouldNotReachHere();
+        } finally {
+            listCritical.release();
         }
         return null;
     }
 
-    // Methods called from native code during library loading.
+    /*
+     * Methods called from native code during library loading. These methods are single threaded by
+     * virtue of the Semaphore in loadPackageDLL.
+     */
 
     /**
      * Upcall from native to set the routines of type denoted by {@code nstOrd}.
@@ -383,7 +415,7 @@ public class DLL {
     }
 
     /**
-     * Upcxall from native to create a {@link DotSymbol} value.
+     * Upcall from native to create a {@link DotSymbol} value.
      */
     public static DotSymbol setDotSymbolValues(String name, long fun, int numArgs) {
         return new DotSymbol(name, fun, numArgs);
-- 
GitLab