From d9e18431c7184655544c3e2a10fce545b929e1d5 Mon Sep 17 00:00:00 2001
From: Lukas Stadler <lukas.stadler@oracle.com>
Date: Fri, 3 Mar 2017 13:15:45 +0100
Subject: [PATCH] class cache in Unique

---
 .../truffle/r/nodes/builtin/base/Unique.java   | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java
index 8898ee1611..88a1f1c89d 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java
@@ -93,8 +93,11 @@ public abstract class Unique extends RBuiltinNode {
     }
 
     @SuppressWarnings("unused")
-    @Specialization
-    protected RStringVector doUnique(RAbstractStringVector vec, byte incomparables, byte fromLast, int nmax) {
+    @Specialization(guards = "vecIn.getClass() == vecClass")
+    protected RStringVector doUniqueCachedString(RAbstractStringVector vecIn, byte incomparables, byte fromLast, int nmax,
+                    @Cached("vecIn.getClass()") Class<? extends RAbstractStringVector> vecClass) {
+        RAbstractStringVector vec = vecClass.cast(vecIn);
+        reportWork(vec.getLength());
         if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) {
             NonRecursiveHashSet<String> set = new NonRecursiveHashSet<>(vec.getLength());
             String[] data = new String[vec.getLength()];
@@ -120,6 +123,11 @@ public abstract class Unique extends RBuiltinNode {
         }
     }
 
+    @Specialization(replaces = "doUniqueCachedString")
+    protected RStringVector doUnique(RAbstractStringVector vec, byte incomparables, byte fromLast, int nmax) {
+        return doUniqueCachedString(vec, incomparables, fromLast, nmax, RAbstractStringVector.class);
+    }
+
     // these are intended to stay private as they will go away once we figure out which external
     // library to use
 
@@ -245,6 +253,7 @@ public abstract class Unique extends RBuiltinNode {
     protected RIntVector doUniqueCached(RAbstractIntVector vecIn, byte incomparables, byte fromLast, int nmax,
                     @Cached("vecIn.getClass()") Class<? extends RAbstractIntVector> vecClass) {
         RAbstractIntVector vec = vecClass.cast(vecIn);
+        reportWork(vec.getLength());
         if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) {
             NonRecursiveHashSetInt set = new NonRecursiveHashSetInt();
             int[] data = new int[16];
@@ -296,6 +305,7 @@ public abstract class Unique extends RBuiltinNode {
     @Specialization(guards = "!lengthOne(list)")
     @TruffleBoundary
     protected RList doUnique(RList list, byte incomparables, byte fromLast, int nmax) {
+        reportWork(list.getLength());
         /*
          * Brute force, as manual says: Using this for lists is potentially slow, especially if the
          * elements are not atomic vectors (see vector) or differ only in their attributes. In the
@@ -375,6 +385,7 @@ public abstract class Unique extends RBuiltinNode {
     @SuppressWarnings("unused")
     @Specialization
     protected RDoubleVector doUnique(RAbstractDoubleVector vec, byte incomparables, byte fromLast, int nmax) {
+        reportWork(vec.getLength());
         if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) {
             NonRecursiveHashSetDouble set = new NonRecursiveHashSetDouble(vec.getLength());
             double[] data = new double[vec.getLength()];
@@ -401,6 +412,7 @@ public abstract class Unique extends RBuiltinNode {
     @SuppressWarnings("unused")
     @Specialization
     protected RLogicalVector doUnique(RAbstractLogicalVector vec, byte incomparables, byte fromLast, int nmax) {
+        reportWork(vec.getLength());
         ByteArray dataList = new ByteArray(vec.getLength());
         for (int i = 0; i < vec.getLength(); i++) {
             byte val = vec.getDataAt(i);
@@ -414,6 +426,7 @@ public abstract class Unique extends RBuiltinNode {
     @SuppressWarnings("unused")
     @Specialization
     protected RComplexVector doUnique(RAbstractComplexVector vec, byte incomparables, byte fromLast, int nmax) {
+        reportWork(vec.getLength());
         if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) {
             NonRecursiveHashSet<RComplex> set = new NonRecursiveHashSet<>(vec.getLength());
             double[] data = new double[vec.getLength() * 2];
@@ -441,6 +454,7 @@ public abstract class Unique extends RBuiltinNode {
     @SuppressWarnings("unused")
     @Specialization
     protected RRawVector doUnique(RAbstractRawVector vec, byte incomparables, byte fromLast, int nmax) {
+        reportWork(vec.getLength());
         if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) {
             NonRecursiveHashSet<RRaw> set = new NonRecursiveHashSet<>(vec.getLength());
             byte[] data = new byte[vec.getLength()];
-- 
GitLab