From 09e1aae72534c1f5192d30b29c73249539c8d5e9 Mon Sep 17 00:00:00 2001
From: Adam Welc <adam.welc@oracle.com>
Date: Wed, 19 Aug 2015 10:20:13 -0700
Subject: [PATCH] Optimizations for passing lists through channels.

---
 .../oracle/truffle/r/runtime/RChannel.java    | 93 +++++++++++++++++--
 .../r/test/library/fastr/TestChannels.java    |  3 +-
 2 files changed, 87 insertions(+), 9 deletions(-)

diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RChannel.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RChannel.java
index dad360bda1..c609d74022 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RChannel.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RChannel.java
@@ -25,8 +25,11 @@ package com.oracle.truffle.r.runtime;
 import java.io.*;
 import java.util.concurrent.*;
 
+import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
+import com.oracle.truffle.r.runtime.conn.*;
 import com.oracle.truffle.r.runtime.data.*;
 import com.oracle.truffle.r.runtime.data.model.*;
+import com.oracle.truffle.r.runtime.env.*;
 
 /**
  * Implementation of a channel abstraction used for communication between parallel contexts in
@@ -138,19 +141,74 @@ public class RChannel {
         }
     }
 
-    public static void send(int id, Object data) {
-        Object msg = data;
-        RChannel channel = getChannelFromId(id);
-        if ((msg instanceof RAbstractVector && !(msg instanceof RList)) || msg instanceof RDataFrame || msg instanceof RFactor) {
-            // make sure that what's passed through the channel will be copied on the first
-            // update
-            RShareable shareable = (RShareable) msg;
+    private static class SerializedList {
+
+        private RList list;
+
+        public SerializedList(RList list) {
+            this.list = list;
+        }
+
+        public RList getList() {
+            return list;
+        }
+    }
+
+    public static void makeShared(Object o) {
+        if (o instanceof RShareable) {
+            RShareable shareable = (RShareable) o;
             if (FastROptions.NewStateTransition) {
                 shareable.incRefCount();
                 shareable.incRefCount();
             } else {
                 shareable.makeShared();
             }
+        }
+    }
+
+    private static Object convertPrivate(Object o) throws IOException {
+        if (o instanceof RList) {
+            RList list = (RList) o;
+            return createShareable(list);
+        } else if (!(o instanceof RFunction || o instanceof REnvironment || o instanceof RConnection || o instanceof RLanguage)) {
+            // TODO: should we make internal values shareable?
+            return o;
+        } else {
+            return RSerialize.serialize(o, false, true, RSerialize.DEFAULT_VERSION, null);
+        }
+    }
+
+    @TruffleBoundary
+    private static Object createShareable(RList list) throws IOException {
+        RList newList = list;
+        for (int i = 0; i < list.getLength(); i++) {
+            Object el = list.getDataAt(i);
+            Object newEl = convertPrivate(el);
+            if (el != newEl) {
+                // conversion happened update element
+                if (list == newList) {
+                    // create a shallow copy
+                    newList = (RList) list.copy();
+                }
+                newList.updateDataAt(i, newEl, null);
+            }
+        }
+        return list == newList ? list : new SerializedList(newList);
+    }
+
+    public static void send(int id, Object data) {
+        Object msg = data;
+        RChannel channel = getChannelFromId(id);
+        if (msg instanceof RList) {
+            try {
+                msg = createShareable((RList) msg);
+            } catch (IOException x) {
+                throw RError.error(RError.NO_NODE, RError.Message.GENERIC, "error creating shareable list");
+            }
+        } else if (!(msg instanceof RFunction || msg instanceof REnvironment || msg instanceof RConnection || msg instanceof RLanguage)) {
+            // make sure that what's passed through the channel will be copied on the first
+            // update
+            makeShared(msg);
         } else {
             msg = RSerialize.serialize(msg, false, true, RSerialize.DEFAULT_VERSION, null);
         }
@@ -161,11 +219,30 @@ public class RChannel {
         }
     }
 
+    @TruffleBoundary
+    private static void unserializeList(RList list) throws IOException {
+        for (int i = 0; i < list.getLength(); i++) {
+            Object el = list.getDataAt(i);
+            if (el instanceof SerializedList) {
+                RList elList = ((SerializedList) el).getList();
+                unserializeList(elList);
+                list.updateDataAtAsObject(i, elList, null);
+            } else if (el instanceof byte[]) {
+                list.updateDataAt(i, RSerialize.unserialize((byte[]) el, null, null), null);
+            }
+        }
+    }
+
     public static Object receive(int id) {
         RChannel channel = getChannelFromId(id);
         try {
             Object msg = (id < 0 ? channel.masterToClient : channel.clientToMaster).take();
-            if (msg instanceof byte[]) {
+            if (msg instanceof SerializedList) {
+                RList list = ((SerializedList) msg).getList();
+                // list is already private (a shallow copy - do the appropriate changes in place)
+                unserializeList(list);
+                return list;
+            } else if (msg instanceof byte[]) {
                 return RSerialize.unserialize((byte[]) msg, null, null);
             } else {
                 return msg;
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/fastr/TestChannels.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/fastr/TestChannels.java
index eed50d42ae..87fe609334 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/fastr/TestChannels.java
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/fastr/TestChannels.java
@@ -32,6 +32,7 @@ public class TestChannels extends TestBase {
     public void testChannels() {
         assertEvalFastR("{ ch <- fastr.channel.create(1L); cx <- fastr.context.create(\"SHARED_NOTHING\"); fastr.context.spawn(cx, \"ch <- fastr.channel.get(1L); x<-fastr.channel.receive(ch); x[1]<-7; fastr.channel.send(ch, x)\"); y<-c(42); fastr.channel.send(ch, y); x<-fastr.channel.receive(ch); fastr.context.join(cx); fastr.channel.close(ch); print(c(x,y)) }",
                         "print(c(7L, 42L))");
+        assertEvalFastR("{ ch <- fastr.channel.create(1L); cx <- fastr.context.create(\"SHARED_NOTHING\"); fastr.context.spawn(cx, \"ch <- fastr.channel.get(1L); x<-fastr.channel.receive(ch); x[1][1]<-7; fastr.channel.send(ch, x)\"); y<-list(c(42)); fastr.channel.send(ch, y); x<-fastr.channel.receive(ch); fastr.context.join(cx); fastr.channel.close(ch); print(c(x,y)) }",
+                        "print(list(7L, 42L))");
     }
-
 }
-- 
GitLab