diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java index 8898ee16115b0af79a6551d5cb9347730143e9b2..88a1f1c89d2645a37e916f5b9ef8b51e50f3f217 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java @@ -93,8 +93,11 @@ public abstract class Unique extends RBuiltinNode { } @SuppressWarnings("unused") - @Specialization - protected RStringVector doUnique(RAbstractStringVector vec, byte incomparables, byte fromLast, int nmax) { + @Specialization(guards = "vecIn.getClass() == vecClass") + protected RStringVector doUniqueCachedString(RAbstractStringVector vecIn, byte incomparables, byte fromLast, int nmax, + @Cached("vecIn.getClass()") Class<? extends RAbstractStringVector> vecClass) { + RAbstractStringVector vec = vecClass.cast(vecIn); + reportWork(vec.getLength()); if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) { NonRecursiveHashSet<String> set = new NonRecursiveHashSet<>(vec.getLength()); String[] data = new String[vec.getLength()]; @@ -120,6 +123,11 @@ public abstract class Unique extends RBuiltinNode { } } + @Specialization(replaces = "doUniqueCachedString") + protected RStringVector doUnique(RAbstractStringVector vec, byte incomparables, byte fromLast, int nmax) { + return doUniqueCachedString(vec, incomparables, fromLast, nmax, RAbstractStringVector.class); + } + // these are intended to stay private as they will go away once we figure out which external // library to use @@ -245,6 +253,7 @@ public abstract class Unique extends RBuiltinNode { protected RIntVector doUniqueCached(RAbstractIntVector vecIn, byte incomparables, byte fromLast, int nmax, @Cached("vecIn.getClass()") Class<? extends RAbstractIntVector> vecClass) { RAbstractIntVector vec = vecClass.cast(vecIn); + reportWork(vec.getLength()); if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) { NonRecursiveHashSetInt set = new NonRecursiveHashSetInt(); int[] data = new int[16]; @@ -296,6 +305,7 @@ public abstract class Unique extends RBuiltinNode { @Specialization(guards = "!lengthOne(list)") @TruffleBoundary protected RList doUnique(RList list, byte incomparables, byte fromLast, int nmax) { + reportWork(list.getLength()); /* * Brute force, as manual says: Using this for lists is potentially slow, especially if the * elements are not atomic vectors (see vector) or differ only in their attributes. In the @@ -375,6 +385,7 @@ public abstract class Unique extends RBuiltinNode { @SuppressWarnings("unused") @Specialization protected RDoubleVector doUnique(RAbstractDoubleVector vec, byte incomparables, byte fromLast, int nmax) { + reportWork(vec.getLength()); if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) { NonRecursiveHashSetDouble set = new NonRecursiveHashSetDouble(vec.getLength()); double[] data = new double[vec.getLength()]; @@ -401,6 +412,7 @@ public abstract class Unique extends RBuiltinNode { @SuppressWarnings("unused") @Specialization protected RLogicalVector doUnique(RAbstractLogicalVector vec, byte incomparables, byte fromLast, int nmax) { + reportWork(vec.getLength()); ByteArray dataList = new ByteArray(vec.getLength()); for (int i = 0; i < vec.getLength(); i++) { byte val = vec.getDataAt(i); @@ -414,6 +426,7 @@ public abstract class Unique extends RBuiltinNode { @SuppressWarnings("unused") @Specialization protected RComplexVector doUnique(RAbstractComplexVector vec, byte incomparables, byte fromLast, int nmax) { + reportWork(vec.getLength()); if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) { NonRecursiveHashSet<RComplex> set = new NonRecursiveHashSet<>(vec.getLength()); double[] data = new double[vec.getLength() * 2]; @@ -441,6 +454,7 @@ public abstract class Unique extends RBuiltinNode { @SuppressWarnings("unused") @Specialization protected RRawVector doUnique(RAbstractRawVector vec, byte incomparables, byte fromLast, int nmax) { + reportWork(vec.getLength()); if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) { NonRecursiveHashSet<RRaw> set = new NonRecursiveHashSet<>(vec.getLength()); byte[] data = new byte[vec.getLength()];