From 192228708c188127c82ec63ce39150e827c3f79b Mon Sep 17 00:00:00 2001
From: stepan <stepan.sindelar@oracle.com>
Date: Fri, 26 Aug 2016 15:15:17 +0200
Subject: [PATCH] Rank: converted to cast pipeline

---
 .../truffle/r/nodes/builtin/base/Rank.java    | 79 ++++++++++++-------
 .../com/oracle/truffle/r/runtime/RError.java  |  3 +
 .../truffle/r/test/ExpectedTestOutput.test    | 12 +++
 .../r/test/builtins/TestBuiltin_rank.java     |  7 ++
 4 files changed, 74 insertions(+), 27 deletions(-)

diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Rank.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Rank.java
index 808450ebf3..022d7ec2dc 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Rank.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Rank.java
@@ -12,23 +12,36 @@
  */
 package com.oracle.truffle.r.nodes.builtin.base;
 
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.abstractVectorValue;
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gte0;
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.intNA;
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.notEmpty;
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue;
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.rawValue;
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.stringValue;
+import static com.oracle.truffle.r.runtime.RError.NO_CALLER;
+import static com.oracle.truffle.r.runtime.RError.SHOW_CALLER;
+import static com.oracle.truffle.r.runtime.RError.Message.INVALID_TIES_FOR_RANK;
+import static com.oracle.truffle.r.runtime.RError.Message.INVALID_VALUE;
+import static com.oracle.truffle.r.runtime.RError.Message.RANK_LARGE_N;
+import static com.oracle.truffle.r.runtime.RError.Message.UNIMPLEMENTED_TYPE_IN_GREATER;
 import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
 import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
 
+import java.util.function.Function;
+
 import com.oracle.truffle.api.dsl.Specialization;
+import com.oracle.truffle.api.profiles.BranchProfile;
 import com.oracle.truffle.r.nodes.builtin.CastBuilder;
 import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
 import com.oracle.truffle.r.nodes.builtin.base.OrderNodeGen.CmpNodeGen;
 import com.oracle.truffle.r.nodes.builtin.base.OrderNodeGen.OrderVector1NodeGen;
 import com.oracle.truffle.r.runtime.RError;
-import com.oracle.truffle.r.runtime.RRuntime;
 import com.oracle.truffle.r.runtime.builtins.RBuiltin;
 import com.oracle.truffle.r.runtime.data.RDataFactory;
 import com.oracle.truffle.r.runtime.data.RIntVector;
-import com.oracle.truffle.r.runtime.data.RRawVector;
 import com.oracle.truffle.r.runtime.data.closures.RClosures;
 import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
-import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
 
 @RBuiltin(name = "rank", kind = INTERNAL, parameterNames = {"x", "len", "ties.method"}, behavior = PURE)
@@ -36,18 +49,31 @@ public abstract class Rank extends RBuiltinNode {
 
     @Child private Order.OrderVector1Node orderVector1Node;
     @Child private Order.CmpNode orderCmpNode;
+    private final BranchProfile errorProfile = BranchProfile.create();
 
     private static final Object rho = new Object();
 
     private enum TiesKind {
         AVERAGE,
         MAX,
-        MIN;
+        MIN
     }
 
     @Override
     protected void createCasts(CastBuilder casts) {
-        casts.toInteger(1);
+        // @formatter:off
+        Function<Object, Object> typeFunc = x -> x.getClass().getSimpleName();
+        casts.arg("x").mustBe(abstractVectorValue(), SHOW_CALLER, UNIMPLEMENTED_TYPE_IN_GREATER, typeFunc).
+                mustBe(rawValue().not(), SHOW_CALLER, RError.Message.RAW_SORT);
+        // Note: without long vector support, GnuR treats any non-integer n as n=1; here any numeric value is
+        // accepted and cast to integer, so we allow ourselves to be a bit inconsistent with GnuR in that.
+        casts.arg("len").defaultError(NO_CALLER, INVALID_VALUE, "length(xx)").mustBe(numericValue()).
+                asIntegerVector().
+                mustBe(notEmpty()).
+                findFirst().mustBe(intNA().not().and(gte0()));
+        // Note: ties.method is parsed in the specialization anyway, so the validation of its value happens there
+        casts.arg("ties.method").defaultError(NO_CALLER, INVALID_TIES_FOR_RANK).mustBe(stringValue()).asStringVector().findFirst();
+        // @formatter:on
     }
 
     private Order.OrderVector1Node initOrderVector1() {
@@ -65,30 +91,15 @@ public abstract class Rank extends RBuiltinNode {
     }
 
     @Specialization
-    protected Object rank(RAbstractVector xa, int n, RAbstractStringVector tiesMethod) {
-        if (n < 0 || RRuntime.isNA(n)) {
-            throw RError.error(this, RError.Message.INVALID_ARGUMENT, "length(xx)");
-        }
-        if (xa instanceof RRawVector) {
-            throw RError.error(this, RError.Message.RAW_SORT);
+    protected Object rank(RAbstractVector xa, int inN, String tiesMethod) {
+        int n = inN;
+        if (n > xa.getLength()) {
+            errorProfile.enter();
+            n = xa.getLength();
+            RError.warning(SHOW_CALLER, RANK_LARGE_N);
         }
-        TiesKind tiesKind;
-        switch (tiesMethod.getDataAt(0)) {
-            case "average":
-                tiesKind = TiesKind.AVERAGE;
-                break;
-
-            case "max":
-                tiesKind = TiesKind.MAX;
-                break;
-
-            case "min":
-                tiesKind = TiesKind.MIN;
-                break;
-            default:
-                throw RError.error(this, RError.Message.GENERIC, "invalid ties.method for rank() [should never happen]");
 
-        }
+        TiesKind tiesKind = getTiesKind(tiesMethod);
         int[] ik = null;
         double[] rk = null;
         if (tiesKind == TiesKind.AVERAGE) {
@@ -134,4 +145,18 @@ public abstract class Rank extends RBuiltinNode {
             return RDataFactory.createIntVector(ik, RDataFactory.COMPLETE_VECTOR);
         }
     }
+
+    private TiesKind getTiesKind(String tiesMethod) {
+        switch (tiesMethod) {
+            case "average":
+                return TiesKind.AVERAGE;
+            case "max":
+                return TiesKind.MAX;
+            case "min":
+                return TiesKind.MIN;
+            default:
+                errorProfile.enter();
+                throw RError.error(NO_CALLER, RError.Message.GENERIC, "invalid ties.method for rank() [should never happen]");
+        }
+    }
 }
diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java
index 86696ed0d1..9ed3b4fb3b 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java
@@ -509,6 +509,9 @@ public final class RError extends RuntimeException {
         MUST_BE_ONE_BYTE("invalid %s: must be one byte"),
         INVALID_DECIMAL_SEP("invalid decimal separator"),
         INVALID_QUOTE_SYMBOL("invalid quote symbol set"),
+        INVALID_TIES_FOR_RANK("invalid ties.method for rank() [should never happen]"),
+        UNIMPLEMENTED_TYPE_IN_GREATER("unimplemented type '%s' in greater"),
+        RANK_LARGE_N("parameter 'n' is greater than length(x), GnuR output is non-deterministic, FastR will use n=length(x)"),
         // below: not exactly GNU-R message
         TOO_FEW_POSITIVE_PROBABILITY("too few positive probabilities"),
         DOTS_BOUNDS("The ... list does not contain %s elements"),
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
index cbf0b99dd0..72e033df1a 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
@@ -39742,6 +39742,18 @@ Error in Summary.ordered(c(3L, 2L, 1L), FALSE, na.rm = FALSE) :
 Error in as.POSIXct.default(X[[i]], ...) :
   do not know how to convert 'X[[i]]' to class “POSIXct”
 
+##com.oracle.truffle.r.test.builtins.TestBuiltin_rank.testArgsCasts
+#.Internal(rank(as.raw(42), 42L, 'max'))
+Error: raw vectors cannot be sorted
+
+##com.oracle.truffle.r.test.builtins.TestBuiltin_rank.testArgsCasts
+#.Internal(rank(c(1,2), -3L, 'max'))
+Error: invalid 'length(xx)' value
+
+##com.oracle.truffle.r.test.builtins.TestBuiltin_rank.testArgsCasts
+#.Internal(rank(c(1,2), 2L, 'something'))
+Error: invalid ties.method for rank() [should never happen]
+
 ##com.oracle.truffle.r.test.builtins.TestBuiltin_rank.testRank
 #{ rank(c(10,100,100,1000)) }
 [1] 1.0 2.5 2.5 4.0
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rank.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rank.java
index 6cef0b8e26..6d2c0d341e 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rank.java
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rank.java
@@ -85,4 +85,11 @@ public class TestBuiltin_rank extends TestBase {
         assertEval("{ rank(c(a=1,b=1,c=3,d=NA,e=3), na.last=NA, ties.method=\"min\") }");
         assertEval("{ rank(c(1000, 100, 100, NA, 1, 20), ties.method=\"first\") }");
     }
+
+    @Test
+    public void testArgsCasts() {
+        assertEval(".Internal(rank(c(1,2), -3L, 'max'))"); // negative len: invalid 'length(xx)' value
+        assertEval(".Internal(rank(c(1,2), 2L, 'something'))"); // unknown ties.method: invalid ties.method for rank()
+        assertEval(".Internal(rank(as.raw(42), 42L, 'max'))"); // raw x: raw vectors cannot be sorted
+    }
 }
-- 
GitLab