From 192228708c188127c82ec63ce39150e827c3f79b Mon Sep 17 00:00:00 2001 From: stepan <stepan.sindelar@oracle.com> Date: Fri, 26 Aug 2016 15:15:17 +0200 Subject: [PATCH] Rank: converted to cast pipeline --- .../truffle/r/nodes/builtin/base/Rank.java | 79 ++++++++++++------- .../com/oracle/truffle/r/runtime/RError.java | 3 + .../truffle/r/test/ExpectedTestOutput.test | 12 +++ .../r/test/builtins/TestBuiltin_rank.java | 7 ++ 4 files changed, 74 insertions(+), 27 deletions(-) diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Rank.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Rank.java index 808450ebf3..022d7ec2dc 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Rank.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Rank.java @@ -12,23 +12,36 @@ */ package com.oracle.truffle.r.nodes.builtin.base; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.abstractVectorValue; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gte0; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.intNA; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.notEmpty; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.rawValue; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.stringValue; +import static com.oracle.truffle.r.runtime.RError.NO_CALLER; +import static com.oracle.truffle.r.runtime.RError.SHOW_CALLER; +import static com.oracle.truffle.r.runtime.RError.Message.INVALID_TIES_FOR_RANK; +import static com.oracle.truffle.r.runtime.RError.Message.INVALID_VALUE; +import static com.oracle.truffle.r.runtime.RError.Message.RANK_LARGE_N; +import static com.oracle.truffle.r.runtime.RError.Message.UNIMPLEMENTED_TYPE_IN_GREATER; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; +import java.util.function.Function; + import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.builtin.base.OrderNodeGen.CmpNodeGen; import com.oracle.truffle.r.nodes.builtin.base.OrderNodeGen.OrderVector1NodeGen; import com.oracle.truffle.r.runtime.RError; -import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RIntVector; -import com.oracle.truffle.r.runtime.data.RRawVector; import com.oracle.truffle.r.runtime.data.closures.RClosures; import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; @RBuiltin(name = "rank", kind = INTERNAL, parameterNames = {"x", "len", "ties.method"}, behavior = PURE) @@ -36,18 +49,31 @@ public abstract class Rank extends RBuiltinNode { @Child private Order.OrderVector1Node orderVector1Node; @Child private Order.CmpNode orderCmpNode; + private final BranchProfile errorProfile = BranchProfile.create(); private static final Object rho = new Object(); private enum TiesKind { AVERAGE, MAX, - MIN; + MIN } @Override protected void createCasts(CastBuilder casts) { - casts.toInteger(1); + // @formatter:off + Function<Object, Object> typeFunc = x -> x.getClass().getSimpleName(); + casts.arg("x").mustBe(abstractVectorValue(), SHOW_CALLER, UNIMPLEMENTED_TYPE_IN_GREATER, typeFunc). + mustBe(rawValue().not(), SHOW_CALLER, RError.Message.RAW_SORT); + // Note: in the case of no long vector support, when given anything but integer as n, GnuR behaves as if n=1, + // we allow ourselves to be bit inconsistent with GnuR in that. + casts.arg("len").defaultError(NO_CALLER, INVALID_VALUE, "length(xx)").mustBe(numericValue()). + asIntegerVector(). + mustBe(notEmpty()). + findFirst().mustBe(intNA().not().and(gte0())); + // Note: we parse ties.methods in the Specialization anyway, so the validation of the value is there + casts.arg("ties.method").defaultError(NO_CALLER, INVALID_TIES_FOR_RANK).mustBe(stringValue()).asStringVector().findFirst(); + // @formatter:on } private Order.OrderVector1Node initOrderVector1() { @@ -65,30 +91,15 @@ public abstract class Rank extends RBuiltinNode { } @Specialization - protected Object rank(RAbstractVector xa, int n, RAbstractStringVector tiesMethod) { - if (n < 0 || RRuntime.isNA(n)) { - throw RError.error(this, RError.Message.INVALID_ARGUMENT, "length(xx)"); - } - if (xa instanceof RRawVector) { - throw RError.error(this, RError.Message.RAW_SORT); + protected Object rank(RAbstractVector xa, int inN, String tiesMethod) { + int n = inN; + if (n > xa.getLength()) { + errorProfile.enter(); + n = xa.getLength(); + RError.warning(SHOW_CALLER, RANK_LARGE_N); } - TiesKind tiesKind; - switch (tiesMethod.getDataAt(0)) { - case "average": - tiesKind = TiesKind.AVERAGE; - break; - - case "max": - tiesKind = TiesKind.MAX; - break; - - case "min": - tiesKind = TiesKind.MIN; - break; - default: - throw RError.error(this, RError.Message.GENERIC, "invalid ties.method for rank() [should never happen]"); - } + TiesKind tiesKind = getTiesKind(tiesMethod); int[] ik = null; double[] rk = null; if (tiesKind == TiesKind.AVERAGE) { @@ -134,4 +145,18 @@ public abstract class Rank extends RBuiltinNode { return RDataFactory.createIntVector(ik, RDataFactory.COMPLETE_VECTOR); } } + + private TiesKind getTiesKind(String tiesMethod) { + switch (tiesMethod) { + case "average": + return TiesKind.AVERAGE; + case "max": + return TiesKind.MAX; + case "min": + return TiesKind.MIN; + default: + errorProfile.enter(); + throw RError.error(NO_CALLER, RError.Message.GENERIC, "invalid ties.method for rank() [should never happen]"); + } + } } diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java index 86696ed0d1..9ed3b4fb3b 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java @@ -509,6 +509,9 @@ public final class RError extends RuntimeException { MUST_BE_ONE_BYTE("invalid %s: must be one byte"), INVALID_DECIMAL_SEP("invalid decimal separator"), INVALID_QUOTE_SYMBOL("invalid quote symbol set"), + INVALID_TIES_FOR_RANK("invalid ties.method for rank() [should never happen]"), + UNIMPLEMENTED_TYPE_IN_GREATER("unimplemented type '%s' in greater"), + RANK_LARGE_N("parameter 'n' is greater than length(x), GnuR output is non-deterministic, FastR will use n=length(x)"), // below: not exactly GNU-R message TOO_FEW_POSITIVE_PROBABILITY("too few positive probabilities"), DOTS_BOUNDS("The ... list does not contain %s elements"), diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index cbf0b99dd0..72e033df1a 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -39742,6 +39742,18 @@ Error in Summary.ordered(c(3L, 2L, 1L), FALSE, na.rm = FALSE) : Error in as.POSIXct.default(X[[i]], ...) : do not know how to convert 'X[[i]]' to class “POSIXct” +##com.oracle.truffle.r.test.builtins.TestBuiltin_rank.testArgsCasts +#.Internal(rank(as.raw(42), 42L, 'max')) +Error: raw vectors cannot be sorted + +##com.oracle.truffle.r.test.builtins.TestBuiltin_rank.testArgsCasts +#.Internal(rank(c(1,2), -3L, 'max')) +Error: invalid 'length(xx)' value + +##com.oracle.truffle.r.test.builtins.TestBuiltin_rank.testArgsCasts +#.Internal(rank(c(1,2), 2L, 'something')) +Error: invalid ties.method for rank() [should never happen] + ##com.oracle.truffle.r.test.builtins.TestBuiltin_rank.testRank #{ rank(c(10,100,100,1000)) } [1] 1.0 2.5 2.5 4.0 diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rank.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rank.java index 6cef0b8e26..6d2c0d341e 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rank.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rank.java @@ -85,4 +85,11 @@ public class TestBuiltin_rank extends TestBase { assertEval("{ rank(c(a=1,b=1,c=3,d=NA,e=3), na.last=NA, ties.method=\"min\") }"); assertEval("{ rank(c(1000, 100, 100, NA, 1, 20), ties.method=\"first\") }"); } + + @Test + public void testArgsCasts() { + assertEval(".Internal(rank(c(1,2), -3L, 'max'))"); + assertEval(".Internal(rank(c(1,2), 2L, 'something'))"); + assertEval(".Internal(rank(as.raw(42), 42L, 'max'))"); + } } -- GitLab