From 7e9da700f853bdf96feb4721a39897fcfc8f5afe Mon Sep 17 00:00:00 2001
From: Adam Welc <adam.welc@oracle.com>
Date: Sun, 28 Aug 2016 14:05:35 -0700
Subject: [PATCH] Rewritten parameter casts for the unique builtin.

---
 .../truffle/r/nodes/builtin/base/Unique.java  | 42 +++++++++++++------
 .../r/test/builtins/TestBuiltin_unique.java   |  5 +++
 2 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java
index 74c3cda1ad..ea894dcbba 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java
@@ -22,6 +22,7 @@
  */
 package com.oracle.truffle.r.nodes.builtin.base;
 
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.*;
 import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
 import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
 
@@ -32,11 +33,12 @@ import com.oracle.truffle.api.CompilerDirectives;
 import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
 import com.oracle.truffle.api.dsl.Specialization;
 import com.oracle.truffle.api.profiles.ConditionProfile;
+import com.oracle.truffle.r.nodes.builtin.CastBuilder;
 import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
+import com.oracle.truffle.r.runtime.RError;
 import com.oracle.truffle.r.runtime.RRuntime;
 import com.oracle.truffle.r.runtime.Utils;
 import com.oracle.truffle.r.runtime.builtins.RBuiltin;
-import com.oracle.truffle.r.runtime.data.RArgsValuesAndNames;
 import com.oracle.truffle.r.runtime.data.RComplex;
 import com.oracle.truffle.r.runtime.data.RComplexVector;
 import com.oracle.truffle.r.runtime.data.RDataFactory;
@@ -54,9 +56,9 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
+import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
 
-// Implements default S3 method
-@RBuiltin(name = "unique", kind = INTERNAL, parameterNames = {"x", "incomparables", "fromLast", "nmax", "..."}, behavior = PURE)
+@RBuiltin(name = "unique", kind = INTERNAL, parameterNames = {"x", "incomparables", "fromLast", "nmax"}, behavior = PURE)
 // TODO A more efficient implementation is in order; GNU R uses hash tables so perhaps we should
 // consider using one of the existing libraries that offer hash table implementations for primitive
 // types
@@ -66,15 +68,31 @@ public abstract class Unique extends RBuiltinNode {
 
     private final ConditionProfile bigProfile = ConditionProfile.createBinaryProfile();
 
+    @Override
+    protected void createCasts(CastBuilder casts) {
+        // these are similar to those in DuplicatedFunctions.java
+        casts.arg("x").mustBe(nullValue().or(abstractVectorValue()), RError.SHOW_CALLER, RError.Message.APPLIES_TO_VECTORS,
+                        "unique()").mapIf(nullValue().not(), asVector());
+        // not much more can be done for incomparables as it is either a vector of incomparable
+        // values or a (single) logical value
+        // TODO: coercion error must be handled by specialization as it depends on type of x (much
+        // like in duplicated)
+        casts.arg("incomparables").asVector(true);
+        casts.arg("fromLast").asLogicalVector().findFirst(RRuntime.LOGICAL_FALSE);
+        // currently not supported and not tested, but NA is a correct value (the same for empty
+        // vectors) whereas 0 is not (throws an error)
+        casts.arg("nmax").asIntegerVector().findFirst(RRuntime.INT_NA);
+    }
+
     @SuppressWarnings("unused")
     @Specialization
-    protected RNull doUnique(RNull vec, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) {
+    protected RNull doUnique(RNull vec, RAbstractVector incomparables, byte fromLast, int nmax) {
         return vec;
     }
 
     @SuppressWarnings("unused")
     @Specialization
-    protected RStringVector doUnique(RAbstractStringVector vec, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) {
+    protected RStringVector doUnique(RAbstractStringVector vec, RAbstractVector incomparables, byte fromLast, int nmax) {
         if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) {
             Utils.NonRecursiveHashSet<String> set = new Utils.NonRecursiveHashSet<>(vec.getLength());
             String[] data = new String[vec.getLength()];
@@ -230,7 +248,7 @@ public abstract class Unique extends RBuiltinNode {
 
     @SuppressWarnings("unused")
     @Specialization
-    protected RIntVector doUnique(RAbstractIntVector vec, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) {
+    protected RIntVector doUnique(RAbstractIntVector vec, RAbstractVector incomparables, byte fromLast, int nmax) {
         if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) {
             NonRecursiveHashSetInt set = new NonRecursiveHashSetInt();
             int[] data = new int[16];
@@ -259,7 +277,7 @@ public abstract class Unique extends RBuiltinNode {
 
     @SuppressWarnings("unused")
     @Specialization(guards = "lengthOne(list)")
-    protected RList doUniqueL1(RList list, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) {
+    protected RList doUniqueL1(RList list, RAbstractVector incomparables, byte fromLast, int nmax) {
         return (RList) list.copyDropAttributes();
     }
 
@@ -276,7 +294,7 @@ public abstract class Unique extends RBuiltinNode {
     @SuppressWarnings("unused")
     @Specialization(guards = "!lengthOne(list)")
     @TruffleBoundary
-    protected RList doUnique(RList list, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) {
+    protected RList doUnique(RList list, RAbstractVector incomparables, byte fromLast, int nmax) {
         /*
          * Brute force, as manual says: Using this for lists is potentially slow, especially if the
          * elements are not atomic vectors (see vector) or differ only in their attributes. In the
@@ -355,7 +373,7 @@ public abstract class Unique extends RBuiltinNode {
 
     @SuppressWarnings("unused")
     @Specialization
-    protected RDoubleVector doUnique(RAbstractDoubleVector vec, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) {
+    protected RDoubleVector doUnique(RAbstractDoubleVector vec, RAbstractVector incomparables, byte fromLast, int nmax) {
         if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) {
             Utils.NonRecursiveHashSetDouble set = new Utils.NonRecursiveHashSetDouble(vec.getLength());
             double[] data = new double[vec.getLength()];
@@ -381,7 +399,7 @@ public abstract class Unique extends RBuiltinNode {
 
     @SuppressWarnings("unused")
     @Specialization
-    protected RLogicalVector doUnique(RAbstractLogicalVector vec, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) {
+    protected RLogicalVector doUnique(RAbstractLogicalVector vec, RAbstractVector incomparables, byte fromLast, int nmax) {
         ByteArray dataList = new ByteArray(vec.getLength());
         for (int i = 0; i < vec.getLength(); i++) {
             byte val = vec.getDataAt(i);
@@ -394,7 +412,7 @@ public abstract class Unique extends RBuiltinNode {
 
     @SuppressWarnings("unused")
     @Specialization
-    protected RComplexVector doUnique(RAbstractComplexVector vec, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) {
+    protected RComplexVector doUnique(RAbstractComplexVector vec, RAbstractVector incomparables, byte fromLast, int nmax) {
         if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) {
             Utils.NonRecursiveHashSet<RComplex> set = new Utils.NonRecursiveHashSet<>(vec.getLength());
             double[] data = new double[vec.getLength() * 2];
@@ -421,7 +439,7 @@ public abstract class Unique extends RBuiltinNode {
 
     @SuppressWarnings("unused")
     @Specialization
-    protected RRawVector doUnique(RAbstractRawVector vec, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) {
+    protected RRawVector doUnique(RAbstractRawVector vec, RAbstractVector incomparables, byte fromLast, int nmax) {
         if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) {
             Utils.NonRecursiveHashSet<RRaw> set = new Utils.NonRecursiveHashSet<>(vec.getLength());
             byte[] data = new byte[vec.getLength()];
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_unique.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_unique.java
index ff067fd228..fb84a6827c 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_unique.java
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_unique.java
@@ -189,5 +189,10 @@ public class TestBuiltin_unique extends TestBase {
     @Test
     public void testUnique() {
         assertEval("{x<-factor(c(\"a\", \"b\", \"a\")); unique(x) }");
+
+        assertEval("{ x<-quote(f(7, 42)); unique(x) }");
+        assertEval("{ x<-function() 42; unique(x) }");
+        assertEval(Ignored.Unknown, "{ unique(c(1,2,1), incomparables=function() 42) }");
+
     }
 }
-- 
GitLab