From 23b0f8df2cbb2af8747c74e1640fc2a866549aa0 Mon Sep 17 00:00:00 2001
From: Zbynek Slajchrt <zbynek.slajchrt@oracle.com>
Date: Mon, 26 Jun 2017 11:02:19 +0200
Subject: [PATCH] The buffer enlargement implemented in RSerialize

---
 .../r/nodes/function/ArgumentMatcher.java     |   1 +
 .../oracle/truffle/r/runtime/RSerialize.java  | 192 +++++++++++-------
 .../r/test/runtime/TestRSerialize.java        | 115 +++++++++++
 3 files changed, 240 insertions(+), 68 deletions(-)
 create mode 100644 com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/runtime/TestRSerialize.java

diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/ArgumentMatcher.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/ArgumentMatcher.java
index 778adee698..c1ee404c67 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/ArgumentMatcher.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/ArgumentMatcher.java
@@ -456,6 +456,7 @@ public class ArgumentMatcher {
      *
      * @see com.oracle.truffle.r.nodes.function.PromiseNode.InlineVarArgsNode
      */
+    @SuppressWarnings("javadoc")
     private static RNode updateInlinedArg(RNode node) {
         if (!(node instanceof WrapArgumentNode)) {
             return node;
diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RSerialize.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RSerialize.java
index 8473264dd4..15f8e5cea2 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RSerialize.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RSerialize.java
@@ -1060,9 +1060,88 @@ public class RSerialize {
 
         private static final int READ_BUFFER_SIZE = 32 * 1024;
 
-        private final byte[] buf;
-        private int size;
-        private int offset;
+        private final class Buffer {
+            private final byte[] buf;
+            private int size;
+            private int offset;
+
+            Buffer(byte[] buf) {
+                this.buf = buf;
+            }
+
+            int readInt() {
+                return ((buf[offset++] & 0xff) << 24 | (buf[offset++] & 0xff) << 16 | (buf[offset++] & 0xff) << 8 | (buf[offset++] & 0xff));
+            }
+
+            double readDouble() {
+                long val = ((long) (buf[offset++] & 0xff) << 56 | (long) (buf[offset++] & 0xff) << 48 | (long) (buf[offset++] & 0xff) << 40 | (long) (buf[offset++] & 0xff) << 32 |
+                                (long) (buf[offset++] & 0xff) << 24 | (long) (buf[offset++] & 0xff) << 16 | (long) (buf[offset++] & 0xff) << 8 | buf[offset++] & 0xff);
+                return Double.longBitsToDouble(val);
+            }
+
+            @SuppressWarnings("deprecation")
+            String readString(int len) {
+                /*
+                 * This fast path uses a cheaper String constructor if all incoming bytes are in the
+                 * 0-127 range.
+                 */
+                boolean fastEncode = true;
+                for (int i = 0; i < len; i++) {
+                    byte b = buf[offset + i];
+                    if (b < 0) {
+                        fastEncode = false;
+                        break;
+                    }
+                }
+                String result;
+                if (fastEncode) {
+                    result = new String(buf, 0, offset, len);
+                } else {
+                    result = new String(buf, offset, len, StandardCharsets.UTF_8);
+                }
+                offset += len;
+                WeakReference<String> entry;
+                if ((entry = strings.get(result)) != null) {
+                    String string = entry.get();
+                    if (string != null) {
+                        return string;
+                    }
+                }
+                strings.put(result, new WeakReference<>(result));
+                return result;
+            }
+
+            void readRaw(byte[] data) {
+                System.arraycopy(buf, offset, data, 0, data.length);
+                offset += data.length;
+            }
+
+            void readData(int n) throws IOException {
+                if (offset + n > size) {
+                    if (offset != size) {
+                        // copy end piece to beginning
+                        System.arraycopy(buf, offset, buf, 0, size - offset);
+                    }
+                    size -= offset;
+                    offset = 0;
+                    while (size < n) {
+                        // read some more data
+                        int nread = is.read(buf, size, buf.length - size);
+                        if (nread <= 0) {
+                            throw RInternalError.unimplemented("handle unexpected eof");
+                        }
+                        size += nread;
+                    }
+                }
+            }
+        }
+
+        /**
+         * This buffer is used under normal circumstances, i.e. when the read data blocks are
+         * smaller than the initial buffer. The ensureData method creates a special buffer for
+         * reading big chunks of data exceeding the default buffer.
+         */
+        private final Buffer defaultBuffer;
 
         private final WeakHashMap<String, WeakReference<String>> strings = RContext.getInstance().stringMap;
 
@@ -1071,92 +1150,69 @@ public class RSerialize {
             if (is instanceof PByteArrayInputStream) {
                 // we already have the data and we have read the beginning
                 PByteArrayInputStream pbis = (PByteArrayInputStream) is;
-                buf = pbis.getData();
-                size = pbis.getData().length;
-                offset = pbis.pos();
+                defaultBuffer = new Buffer(pbis.getData());
+                defaultBuffer.size = pbis.getData().length;
+                defaultBuffer.offset = pbis.pos();
             } else {
-                buf = new byte[READ_BUFFER_SIZE];
-                size = 0;
-                offset = 0;
+                defaultBuffer = new Buffer(new byte[READ_BUFFER_SIZE]);
+                defaultBuffer.size = 0;
+                defaultBuffer.offset = 0;
             }
         }
 
         @Override
         int readInt() throws IOException {
-            ensureData(4);
-            return ((buf[offset++] & 0xff) << 24 | (buf[offset++] & 0xff) << 16 | (buf[offset++] & 0xff) << 8 | (buf[offset++] & 0xff));
+            return ensureData(4).readInt();
         }
 
         @Override
         double readDouble() throws IOException {
-            ensureData(8);
-            long val = ((long) (buf[offset++] & 0xff) << 56 | (long) (buf[offset++] & 0xff) << 48 | (long) (buf[offset++] & 0xff) << 40 | (long) (buf[offset++] & 0xff) << 32 |
-                            (long) (buf[offset++] & 0xff) << 24 | (long) (buf[offset++] & 0xff) << 16 | (long) (buf[offset++] & 0xff) << 8 | buf[offset++] & 0xff);
-            return Double.longBitsToDouble(val);
+            return ensureData(8).readDouble();
         }
 
         @SuppressWarnings("deprecation")
         @Override
         String readString(int len) throws IOException {
-            ensureData(len);
-            /*
-             * This fast path uses a cheaper String constructor if all incoming bytes are in the
-             * 0-127 range.
-             */
-            boolean fastEncode = true;
-            for (int i = 0; i < len; i++) {
-                byte b = buf[offset + i];
-                if (b < 0) {
-                    fastEncode = false;
-                    break;
-                }
-            }
-            String result;
-            if (fastEncode) {
-                result = new String(buf, 0, offset, len);
-            } else {
-                result = new String(buf, offset, len, StandardCharsets.UTF_8);
-            }
-            offset += len;
-            WeakReference<String> entry;
-            if ((entry = strings.get(result)) != null) {
-                String string = entry.get();
-                if (string != null) {
-                    return string;
-                }
-            }
-            strings.put(result, new WeakReference<>(result));
-            return result;
+            return ensureData(len).readString(len);
         }
 
-        private void ensureData(int n) throws IOException {
-            if (n > buf.length) {
-                throw RInternalError.unimplemented("dynamically enlarge buffer");
-            }
-            if (offset + n > size) {
-                if (offset != size) {
-                    // copy end piece to beginning
-                    System.arraycopy(buf, offset, buf, 0, size - offset);
-                }
-                size -= offset;
-                offset = 0;
-                while (size < n) {
-                    // read some more data
-                    int nread = is.read(buf, size, buf.length - size);
-                    if (nread <= 0) {
-                        throw RInternalError.unimplemented("handle unexpected eof");
-                    }
-                    size += nread;
+        @Override
+        void readRaw(byte[] data) throws IOException {
+            ensureData(data.length).readRaw(data);
+        }
+
+        private Buffer ensureData(int n) throws IOException {
+            Buffer usedBuffer;
+            if (n > defaultBuffer.buf.length) {
+                if (is instanceof PByteArrayInputStream) {
+                    // If the input stream is instance of PByteArrayInputStream, the buffer is
+                    // preloaded and thus no more data can be read beyond the current buffer.
+                    throw new IOException("Premature EOF");
                 }
+
+                // create an enlarged copy of the default buffer
+                byte[] enlargedBuf = new byte[n];
+                System.arraycopy(defaultBuffer.buf, defaultBuffer.offset, enlargedBuf, defaultBuffer.offset, defaultBuffer.size - defaultBuffer.offset);
+                usedBuffer = new Buffer(enlargedBuf);
+                usedBuffer.offset = defaultBuffer.offset;
+                usedBuffer.size = defaultBuffer.size;
+
+                // reset the default buffer
+                defaultBuffer.offset = defaultBuffer.size = 0;
+
+                usedBuffer.readData(n);
+                // The previous statement should entirely fill the temporary buffer.
+                // It is assumed that the caller will read n bytes, making the temporary buffer
+                // disposable. Next time, the default buffer will be used again, unless
+                // n > defaultBuffer.buf.length.
+                assert usedBuffer.size == n;
+            } else {
+                usedBuffer = defaultBuffer;
+                usedBuffer.readData(n);
             }
+            return usedBuffer;
         }
 
-        @Override
-        void readRaw(byte[] data) throws IOException {
-            ensureData(data.length);
-            System.arraycopy(buf, offset, data, 0, data.length);
-            offset += data.length;
-        }
     }
 
     /**
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/runtime/TestRSerialize.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/runtime/TestRSerialize.java
new file mode 100644
index 0000000000..4304919080
--- /dev/null
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/runtime/TestRSerialize.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+package com.oracle.truffle.r.test.runtime;
+
+import java.util.Arrays;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.oracle.truffle.r.runtime.RSerialize;
+import com.oracle.truffle.r.runtime.data.RDataFactory;
+import com.oracle.truffle.r.runtime.data.RStringVector;
+import com.oracle.truffle.r.test.TestBase;
+
+public class TestRSerialize extends TestBase {
+
+    // Buffer enlargement tests
+
+    @Test
+    public void testDeserializeLongString() {
+        char[] chars = new char[1000000];
+        Arrays.fill(chars, 'x');
+        String longString = new String(chars);
+        RStringVector longStringVec = RDataFactory.createStringVector(longString);
+        byte[] serialized = RSerialize.serialize(longStringVec, RSerialize.XDR, RSerialize.DEFAULT_VERSION, null);
+        Object unserialized = RSerialize.unserialize(RDataFactory.createRawVector(serialized));
+
+        Assert.assertTrue(unserialized instanceof RStringVector);
+        Assert.assertEquals(1, ((RStringVector) unserialized).getLength());
+        Assert.assertEquals(longString, ((RStringVector) unserialized).getDataAt(0));
+    }
+
+    @Test
+    public void testDeserializeShortLongStrings() {
+        char[] chars = new char[1000000];
+        Arrays.fill(chars, 'x');
+        String longString = new String(chars);
+        RStringVector longStringVec = RDataFactory.createStringVector(new String[]{"abc", longString}, true);
+        byte[] serialized = RSerialize.serialize(longStringVec, RSerialize.XDR, RSerialize.DEFAULT_VERSION, null);
+        Object unserialized = RSerialize.unserialize(RDataFactory.createRawVector(serialized));
+
+        Assert.assertTrue(unserialized instanceof RStringVector);
+        Assert.assertEquals(2, ((RStringVector) unserialized).getLength());
+        Assert.assertEquals("abc", ((RStringVector) unserialized).getDataAt(0));
+        Assert.assertEquals(longString, ((RStringVector) unserialized).getDataAt(1));
+    }
+
+    @Test
+    public void testDeserializeLongShortStrings() {
+        char[] chars = new char[1000000];
+        Arrays.fill(chars, 'x');
+        String longString = new String(chars);
+        RStringVector longStringVec = RDataFactory.createStringVector(new String[]{longString, "abc"}, true);
+        byte[] serialized = RSerialize.serialize(longStringVec, RSerialize.XDR, RSerialize.DEFAULT_VERSION, null);
+        Object unserialized = RSerialize.unserialize(RDataFactory.createRawVector(serialized));
+
+        Assert.assertTrue(unserialized instanceof RStringVector);
+        Assert.assertEquals(2, ((RStringVector) unserialized).getLength());
+        Assert.assertEquals(longString, ((RStringVector) unserialized).getDataAt(0));
+        Assert.assertEquals("abc", ((RStringVector) unserialized).getDataAt(1));
+    }
+
+    @Test
+    public void testDeserializeShortLongShortStrings() {
+        char[] chars = new char[1000000];
+        Arrays.fill(chars, 'x');
+        String longString = new String(chars);
+        RStringVector longStringVec = RDataFactory.createStringVector(new String[]{"abc", longString, "abc"}, true);
+        byte[] serialized = RSerialize.serialize(longStringVec, RSerialize.XDR, RSerialize.DEFAULT_VERSION, null);
+        Object unserialized = RSerialize.unserialize(RDataFactory.createRawVector(serialized));
+
+        Assert.assertTrue(unserialized instanceof RStringVector);
+        Assert.assertEquals(3, ((RStringVector) unserialized).getLength());
+        Assert.assertEquals("abc", ((RStringVector) unserialized).getDataAt(0));
+        Assert.assertEquals(longString, ((RStringVector) unserialized).getDataAt(1));
+        Assert.assertEquals("abc", ((RStringVector) unserialized).getDataAt(2));
+    }
+
+    @Test
+    public void testDeserializeShortLongShortLongStrings() {
+        char[] chars = new char[1000000];
+        Arrays.fill(chars, 'x');
+        String longString = new String(chars);
+        RStringVector longStringVec = RDataFactory.createStringVector(new String[]{"abc", longString, "abc", longString}, true);
+        byte[] serialized = RSerialize.serialize(longStringVec, RSerialize.XDR, RSerialize.DEFAULT_VERSION, null);
+        Object unserialized = RSerialize.unserialize(RDataFactory.createRawVector(serialized));
+
+        Assert.assertTrue(unserialized instanceof RStringVector);
+        Assert.assertEquals(4, ((RStringVector) unserialized).getLength());
+        Assert.assertEquals("abc", ((RStringVector) unserialized).getDataAt(0));
+        Assert.assertEquals(longString, ((RStringVector) unserialized).getDataAt(1));
+        Assert.assertEquals("abc", ((RStringVector) unserialized).getDataAt(2));
+        Assert.assertEquals(longString, ((RStringVector) unserialized).getDataAt(3));
+    }
+}
-- 
GitLab