From fc787f044bc7817fdb3fb2d47d22321350b67683 Mon Sep 17 00:00:00 2001
From: Florian Angerer <florian.angerer@oracle.com>
Date: Wed, 18 Oct 2017 12:51:26 +0200
Subject: [PATCH] Implemented internal function 'sockSelect'.

---
 .../r/nodes/builtin/base/BasePackage.java     |  3 ++
 .../builtin/base/ConnectionFunctions.java     | 37 +++++++++++++++++
 .../com/oracle/truffle/r/runtime/RError.java  |  4 +-
 .../r/runtime/conn/SocketConnections.java     | 41 +++++++++++++++++++
 4 files changed, 84 insertions(+), 1 deletion(-)

diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java
index 918b90d8a7..c344eb3dac 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java
@@ -36,6 +36,8 @@ import com.oracle.truffle.r.nodes.binary.BinaryBooleanNodeGen;
 import com.oracle.truffle.r.nodes.binary.BinaryBooleanScalarNodeGen;
 import com.oracle.truffle.r.nodes.binary.BinaryBooleanSpecial;
 import com.oracle.truffle.r.nodes.builtin.RBuiltinPackage;
+import com.oracle.truffle.r.nodes.builtin.base.ConnectionFunctions.SockSelect;
+import com.oracle.truffle.r.nodes.builtin.base.ConnectionFunctionsFactory.SockSelectNodeGen;
 import com.oracle.truffle.r.nodes.builtin.base.DebugFunctions.FastRSetBreakpoint;
 import com.oracle.truffle.r.nodes.builtin.base.DebugFunctionsFactory.FastRSetBreakpointNodeGen;
 import com.oracle.truffle.r.nodes.builtin.base.fastpaths.AssignFastPathNodeGen;
@@ -676,6 +678,7 @@ public class BasePackage extends RBuiltinPackage {
         add(SinkFunctions.Sink.class, SinkFunctionsFactory.SinkNodeGen::create);
         add(SinkFunctions.SinkNumber.class, SinkFunctionsFactory.SinkNumberNodeGen::create);
         add(Slot.class, SlotNodeGen::create);
+        add(SockSelect.class, SockSelectNodeGen::create);
         add(SortFunctions.PartialSort.class, SortFunctionsFactory.PartialSortNodeGen::create);
         add(SortFunctions.QSort.class, SortFunctionsFactory.QSortNodeGen::create);
         add(SortFunctions.RadixSort.class, SortFunctionsFactory.RadixSortNodeGen::create);
diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ConnectionFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ConnectionFunctions.java
index 0a8f312ca4..d9d8116275 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ConnectionFunctions.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ConnectionFunctions.java
@@ -30,6 +30,7 @@ import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gte;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.instanceOf;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.integerValue;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalTrue;
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.lte;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.notEmpty;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
@@ -1401,4 +1402,40 @@ public abstract class ConnectionFunctions {
             }
         }
     }
+
+    @RBuiltin(name = "sockSelect", kind = RBuiltinKind.INTERNAL, parameterNames = {"socklist", "write", "timeout"}, behavior = IO)
+    public abstract static class SockSelect extends RBuiltinNode.Arg3 {
+
+        static {
+            Casts casts = new Casts(SockSelect.class);
+            casts.arg("socklist").defaultError(Message.NOT_A_LIST_OF_SOCKETS).mustNotBeMissing().mustNotBeNull().asIntegerVector().findFirst();
+            casts.arg("write").mustNotBeMissing().mustBe(logicalValue()).asLogicalVector().findFirst().map(toBoolean());
+            casts.arg("timeout").mustNotBeMissing().asIntegerVector().findFirst();
+        }
+
+        @Specialization
+        protected RLogicalVector selectMultiple(RAbstractIntVector socklist, boolean write, int timeout) {
+            RSocketConnection[] socketConnections = getSocketConnections(socklist);
+            try {
+                byte[] selected = RSocketConnection.select(socketConnections, write, timeout * 1000L);
+                return RDataFactory.createLogicalVector(selected, true);
+            } catch (IOException e) {
+                throw error(RError.Message.GENERIC, e.getMessage());
+            }
+        }
+
+        @TruffleBoundary
+        private RSocketConnection[] getSocketConnections(RAbstractIntVector socklist) {
+            RSocketConnection[] socketConnections = new RSocketConnection[socklist.getLength()];
+            for (int i = 0; i < socklist.getLength(); i++) {
+                BaseRConnection baseConnection = getBaseConnection(RConnection.fromIndex(socklist.getDataAt(i)));
+                if (baseConnection instanceof RSocketConnection) {
+                    socketConnections[i] = (RSocketConnection) baseConnection;
+                } else {
+                    throw error(Message.NOT_A_SOCKET_CONNECTION);
+                }
+            }
+            return socketConnections;
+        }
+    }
 }
diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java
index e8d02437ed..d8cde8b181 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java
@@ -904,7 +904,9 @@ public final class RError extends RuntimeException implements TruffleException {
         LIST_NO_VALID_NAMES("list argument has no valid names"),
         VALUES_MUST_BE_LENGTH("values must be length %s,\n but FUN(X[[%d]]) result is length %s"),
         OS_REQUEST_LOCALE("OS reports request to set locale to \"%s\" cannot be honored"),
-        INVALID_TYPE("invalid type (%s) for '%s' (must be a %s)");
+        INVALID_TYPE("invalid type (%s) for '%s' (must be a %s)"),
+        NOT_A_LIST_OF_SOCKETS("not a list of sockets"),
+        NOT_A_SOCKET_CONNECTION("not a socket connection");
 
         public final String message;
         final boolean hasArgs;
diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/conn/SocketConnections.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/conn/SocketConnections.java
index ae41e53ce7..6810fad8f8 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/conn/SocketConnections.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/conn/SocketConnections.java
@@ -26,10 +26,15 @@ import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.net.Socket;
 import java.nio.channels.ByteChannel;
+import java.nio.channels.SelectionKey;
+import java.nio.channels.Selector;
 import java.nio.channels.ServerSocketChannel;
 import java.nio.channels.SocketChannel;
+import java.util.HashMap;
+import java.util.Set;
 
 import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
+import com.oracle.truffle.r.runtime.RRuntime;
 import com.oracle.truffle.r.runtime.conn.ConnectionSupport.AbstractOpenMode;
 import com.oracle.truffle.r.runtime.conn.ConnectionSupport.BaseRConnection;
 import com.oracle.truffle.r.runtime.conn.ConnectionSupport.ConnectionClass;
@@ -76,6 +81,36 @@ public class SocketConnections {
         public String getSummaryDescription() {
             return (server ? "<-" : "->") + host + ":" + port;
         }
+
+        @TruffleBoundary
+        public static byte[] select(RSocketConnection[] socketConnections, boolean write, long timeout) throws IOException {
+            int op = write ? SelectionKey.OP_WRITE : SelectionKey.OP_READ;
+
+            HashMap<RSocketConnection, SelectionKey> table = new HashMap<>();
+            Selector selector = Selector.open();
+            for (RSocketConnection con : socketConnections) {
+                con.checkOpen();
+
+                SocketChannel sc = (SocketChannel) con.theConnection.getChannel();
+                sc.configureBlocking(false);
+                table.put(con, sc.register(selector, op));
+            }
+            int select;
+            if (timeout >= 0) {
+                select = selector.select(timeout);
+            } else {
+                select = selector.select();
+            }
+
+            byte[] result = new byte[socketConnections.length];
+            if (select > 0) {
+                Set<SelectionKey> selectedKeys = selector.selectedKeys();
+                for (int i = 0; i < result.length; i++) {
+                    result[i] = RRuntime.asLogical(selectedKeys.contains(table.get(socketConnections[i])));
+                }
+            }
+            return result;
+        }
     }
 
     private abstract static class RSocketReadWriteConnection extends DelegateReadWriteRConnection {
@@ -170,6 +205,12 @@ public class SocketConnections {
             super.close();
             connectionSocket.close();
         }
+
+        @Override
+        public ByteChannel getChannel() {
+            return connectionSocket;
+        }
+
     }
 
     private static class RClientSocketConnection extends RSocketReadWriteConnection {
-- 
GitLab