From aee854fbc5f65173f997904046ca60bdc13909a4 Mon Sep 17 00:00:00 2001
From: Lukas Stadler <lukas.stadler@oracle.com>
Date: Fri, 17 Nov 2017 13:22:18 +0100
Subject: [PATCH] add VectorAccess for primitive types

---
 .../data/nodes/FastPathVectorAccess.java      |  18 +--
 .../data/nodes/PrimitiveVectorAccess.java     | 148 ++++++++++++++++++
 .../data/nodes/SlowPathVectorAccess.java      |   4 +-
 .../r/runtime/data/nodes/VectorAccess.java    |  73 +++++++--
 4 files changed, 216 insertions(+), 27 deletions(-)
 create mode 100644 com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/PrimitiveVectorAccess.java

diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/FastPathVectorAccess.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/FastPathVectorAccess.java
index e17d63ac7c..4d8948d3f6 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/FastPathVectorAccess.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/FastPathVectorAccess.java
@@ -44,8 +44,8 @@ public abstract class FastPathVectorAccess extends VectorAccess {
 
     protected boolean naReported; // TODO: move this into the iterator
 
-    protected FastPathVectorAccess(RAbstractContainer value) {
-        super(value.getClass(), value.getInternalStore() != null);
+    protected FastPathVectorAccess(Object value) {
+        super(value.getClass(), value instanceof RAbstractContainer ? ((RAbstractContainer) value).getInternalStore() != null : true);
     }
 
     @Override
@@ -63,7 +63,7 @@ public abstract class FastPathVectorAccess extends VectorAccess {
 
     public abstract static class FastPathFromIntAccess extends FastPathVectorAccess {
 
-        public FastPathFromIntAccess(RAbstractContainer value) {
+        public FastPathFromIntAccess(Object value) {
             super(value);
         }
 
@@ -147,7 +147,7 @@ public abstract class FastPathVectorAccess extends VectorAccess {
 
     public abstract static class FastPathFromDoubleAccess extends FastPathVectorAccess {
 
-        public FastPathFromDoubleAccess(RAbstractContainer value) {
+        public FastPathFromDoubleAccess(Object value) {
             super(value);
         }
 
@@ -240,7 +240,7 @@ public abstract class FastPathVectorAccess extends VectorAccess {
 
     public abstract static class FastPathFromLogicalAccess extends FastPathVectorAccess {
 
-        public FastPathFromLogicalAccess(RAbstractContainer value) {
+        public FastPathFromLogicalAccess(Object value) {
             super(value);
         }
 
@@ -323,7 +323,7 @@ public abstract class FastPathVectorAccess extends VectorAccess {
 
     public abstract static class FastPathFromRawAccess extends FastPathVectorAccess {
 
-        public FastPathFromRawAccess(RAbstractContainer value) {
+        public FastPathFromRawAccess(Object value) {
             super(value);
         }
 
@@ -398,7 +398,7 @@ public abstract class FastPathVectorAccess extends VectorAccess {
 
     public abstract static class FastPathFromComplexAccess extends FastPathVectorAccess {
 
-        public FastPathFromComplexAccess(RAbstractContainer value) {
+        public FastPathFromComplexAccess(Object value) {
             super(value);
         }
 
@@ -491,7 +491,7 @@ public abstract class FastPathVectorAccess extends VectorAccess {
 
     public abstract static class FastPathFromStringAccess extends FastPathVectorAccess {
 
-        public FastPathFromStringAccess(RAbstractContainer value) {
+        public FastPathFromStringAccess(Object value) {
             super(value);
         }
 
@@ -564,7 +564,7 @@ public abstract class FastPathVectorAccess extends VectorAccess {
 
     public abstract static class FastPathFromListAccess extends FastPathVectorAccess {
 
-        public FastPathFromListAccess(RAbstractContainer value) {
+        public FastPathFromListAccess(Object value) {
             super(value);
         }
 
diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/PrimitiveVectorAccess.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/PrimitiveVectorAccess.java
new file mode 100644
index 0000000000..9e92d40375
--- /dev/null
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/PrimitiveVectorAccess.java
@@ -0,0 +1,148 @@
+/*
+ * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+package com.oracle.truffle.r.runtime.data.nodes;
+
+import com.oracle.truffle.r.runtime.RInternalError;
+import com.oracle.truffle.r.runtime.RType;
+import com.oracle.truffle.r.runtime.data.RNull;
+import com.oracle.truffle.r.runtime.data.nodes.FastPathVectorAccess.FastPathFromDoubleAccess;
+import com.oracle.truffle.r.runtime.data.nodes.FastPathVectorAccess.FastPathFromIntAccess;
+import com.oracle.truffle.r.runtime.data.nodes.FastPathVectorAccess.FastPathFromListAccess;
+import com.oracle.truffle.r.runtime.data.nodes.FastPathVectorAccess.FastPathFromLogicalAccess;
+import com.oracle.truffle.r.runtime.data.nodes.FastPathVectorAccess.FastPathFromStringAccess;
+import com.oracle.truffle.r.runtime.data.nodes.SlowPathVectorAccess.SlowPathFromDoubleAccess;
+import com.oracle.truffle.r.runtime.data.nodes.SlowPathVectorAccess.SlowPathFromIntAccess;
+import com.oracle.truffle.r.runtime.data.nodes.SlowPathVectorAccess.SlowPathFromListAccess;
+import com.oracle.truffle.r.runtime.data.nodes.SlowPathVectorAccess.SlowPathFromLogicalAccess;
+import com.oracle.truffle.r.runtime.data.nodes.SlowPathVectorAccess.SlowPathFromStringAccess;
+
+public abstract class PrimitiveVectorAccess {
+
+    public static VectorAccess create(Object value) {
+        if (value instanceof Integer) {
+            return new FastPathFromIntAccess(value) {
+                @Override
+                protected int getInt(Object store, int index) {
+                    return (Integer) store;
+                }
+            };
+        } else if (value instanceof Double) {
+            return new FastPathFromDoubleAccess(value) {
+                @Override
+                protected double getDouble(Object store, int index) {
+                    return (Double) store;
+                }
+            };
+        } else if (value instanceof Byte) {
+            return new FastPathFromLogicalAccess(value) {
+                @Override
+                protected byte getLogical(Object store, int index) {
+                    return (Byte) store;
+                }
+            };
+        } else if (value instanceof String) {
+            return new FastPathFromStringAccess(value) {
+                @Override
+                protected String getString(Object store, int index) {
+                    return (String) store;
+                }
+            };
+        } else if (value instanceof RNull) {
+            return new FastPathFromListAccess(value) {
+                @Override
+                public RType getType() {
+                    return RType.Null;
+                }
+
+                @Override
+                protected int getLength(Object vector) {
+                    return 0;
+                }
+
+                @Override
+                protected Object getListElement(Object store, int index) {
+                    throw RInternalError.shouldNotReachHere();
+                }
+            };
+        } else {
+            return null;
+        }
+    }
+
+    private static final SlowPathFromIntAccess SLOW_PATH_INT = new SlowPathFromIntAccess() {
+        @Override
+        protected int getInt(Object store, int index) {
+            return (Integer) store;
+        }
+    };
+    private static final SlowPathFromDoubleAccess SLOW_PATH_DOUBLE = new SlowPathFromDoubleAccess() {
+        @Override
+        protected double getDouble(Object store, int index) {
+            return (Double) store;
+        }
+    };
+    private static final SlowPathFromLogicalAccess SLOW_PATH_LOGICAL = new SlowPathFromLogicalAccess() {
+        @Override
+        protected byte getLogical(Object store, int index) {
+            return (Byte) store;
+        }
+    };
+    private static final SlowPathFromStringAccess SLOW_PATH_STRING = new SlowPathFromStringAccess() {
+        @Override
+        protected String getString(Object store, int index) {
+            return (String) store;
+        }
+    };
+    private static final SlowPathFromListAccess SLOW_PATH_NULL = new SlowPathFromListAccess() {
+        @Override
+        public RType getType() {
+            return RType.Null;
+        }
+
+        @Override
+        protected int getLength(Object vector) {
+            return 0;
+        }
+
+        @Override
+        protected Object getListElement(Object store, int index) {
+            throw RInternalError.shouldNotReachHere();
+        }
+    };
+
+    public static VectorAccess createSlowPath(Object value) {
+        if (value instanceof Integer) {
+            return SLOW_PATH_INT;
+        } else if (value instanceof Double) {
+            return SLOW_PATH_DOUBLE;
+        } else if (value instanceof Byte) {
+            return SLOW_PATH_LOGICAL;
+        } else if (value instanceof String) {
+            return SLOW_PATH_STRING;
+        } else if (value instanceof RNull) {
+            return SLOW_PATH_NULL;
+        } else {
+            return null;
+        }
+    }
+}
diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/SlowPathVectorAccess.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/SlowPathVectorAccess.java
index 18fb51f466..5aeab7e217 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/SlowPathVectorAccess.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/SlowPathVectorAccess.java
@@ -43,8 +43,8 @@ public abstract class SlowPathVectorAccess extends VectorAccess {
     protected boolean naReported; // TODO: move this into the iterator
 
     protected SlowPathVectorAccess() {
-        // VectorAccess.supports has an assertion that relies on this being RAbstractContainer.class
-        super(RAbstractContainer.class, true);
+        // VectorAccess.supports has an assertion that relies on this being Object.class
+        super(Object.class, true);
     }
 
     @Override
diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/VectorAccess.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/VectorAccess.java
index faa2dfb98c..e64b834867 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/VectorAccess.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/VectorAccess.java
@@ -23,6 +23,7 @@
 package com.oracle.truffle.r.runtime.data.nodes;
 
 import com.oracle.truffle.api.CompilerAsserts;
+import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
 import com.oracle.truffle.api.dsl.Cached;
 import com.oracle.truffle.api.dsl.Specialization;
 import com.oracle.truffle.api.nodes.Node;
@@ -69,10 +70,10 @@ public abstract class VectorAccess extends Node {
 
     public final NACheck na = NACheck.create();
 
-    protected final Class<? extends RAbstractContainer> clazz;
+    protected final Class<?> clazz;
     protected final boolean hasStore;
 
-    public VectorAccess(Class<? extends RAbstractContainer> clazz, boolean hasStore) {
+    public VectorAccess(Class<?> clazz, boolean hasStore) {
         CompilerAsserts.neverPartOfCompilation();
         this.clazz = clazz;
         this.hasStore = hasStore;
@@ -204,13 +205,17 @@ public abstract class VectorAccess extends Node {
 
     protected abstract boolean isNA(Object store, int index);
 
-    public final RAbstractContainer cast(Object value) {
+    public final Object cast(Object value) {
         return clazz.cast(value);
     }
 
     public final boolean supports(Object value) {
-        assert clazz != RAbstractContainer.class : "cannot call 'supports' on slow path vector access";
-        return value.getClass() == clazz && (cast(value).getInternalStore() != null) == hasStore;
+        assert clazz != Object.class : "cannot call 'supports' on slow path vector access";
+        if (value.getClass() != clazz) {
+            return false;
+        }
+        Object castVector = cast(value);
+        return !(castVector instanceof RAbstractContainer) || (((RAbstractContainer) castVector).getInternalStore() != null) == hasStore;
     }
 
     protected abstract Object getStore(RAbstractContainer vector);
@@ -219,16 +224,26 @@ public abstract class VectorAccess extends Node {
         return vector.getLength();
     }
 
+    protected int getLength(@SuppressWarnings("unused") Object vector) {
+        return 1;
+    }
+
     /**
      * Creates a new iterator that will point to before the beginning of the vector, so that
      * {@link #next(SequentialIterator)} will move it to the first element.
      */
-    public final SequentialIterator access(RAbstractContainer vector) {
-        RAbstractContainer castVector = cast(vector);
-        int length = getLength(castVector);
-        RBaseNode.reportWork(this, length);
-        na.enable(castVector);
-        return new SequentialIterator(getStore(castVector), length);
+    public final SequentialIterator access(Object vector) {
+        Object castVector = cast(vector);
+        if (castVector instanceof RAbstractContainer) {
+            RAbstractContainer container = (RAbstractContainer) castVector;
+            int length = getLength(container);
+            RBaseNode.reportWork(this, length);
+            na.enable(container);
+            return new SequentialIterator(getStore(container), length);
+        } else {
+            na.enable(true);
+            return new SequentialIterator(castVector, getLength(castVector));
+        }
     }
 
     @SuppressWarnings("static-method")
@@ -344,11 +359,17 @@ public abstract class VectorAccess extends Node {
      * Creates a new random access on the given vector.
      */
     public final RandomIterator randomAccess(RAbstractContainer vector) {
-        RAbstractContainer castVector = cast(vector);
-        int length = getLength(castVector);
-        RBaseNode.reportWork(this, length);
-        na.enable(castVector);
-        return new RandomIterator(getStore(castVector), length);
+        Object castVector = cast(vector);
+        if (castVector instanceof RAbstractContainer) {
+            RAbstractContainer container = (RAbstractContainer) castVector;
+            int length = getLength(container);
+            RBaseNode.reportWork(this, length);
+            na.enable(container);
+            return new RandomIterator(getStore(container), length);
+        } else {
+            na.enable(true);
+            return new RandomIterator(castVector, getLength(castVector));
+        }
     }
 
     @SuppressWarnings("static-method")
@@ -451,6 +472,7 @@ public abstract class VectorAccess extends Node {
     }
 
     public static VectorAccess createNew(RType type) {
+        CompilerAsserts.neverPartOfCompilation();
         switch (type) {
             case Character:
                 return Lazy.TEMPLATE_CHARACTER.access();
@@ -477,6 +499,7 @@ public abstract class VectorAccess extends Node {
         }
     }
 
+    @TruffleBoundary
     public static VectorAccess createSlowPathNew(RType type) {
         switch (type) {
             case Character:
@@ -503,4 +526,22 @@ public abstract class VectorAccess extends Node {
                 throw RInternalError.shouldNotReachHere();
         }
     }
+
+    public static VectorAccess create(Object value) {
+        CompilerAsserts.neverPartOfCompilation();
+        if (value instanceof RAbstractContainer) {
+            return ((RAbstractContainer) value).access();
+        } else {
+            return PrimitiveVectorAccess.create(value);
+        }
+    }
+
+    @TruffleBoundary
+    public static VectorAccess createSlowPath(Object value) {
+        if (value instanceof RAbstractContainer) {
+            return ((RAbstractContainer) value).slowPathAccess();
+        } else {
+            return PrimitiveVectorAccess.createSlowPath(value);
+        }
+    }
 }
-- 
GitLab