From 7e9da700f853bdf96feb4721a39897fcfc8f5afe Mon Sep 17 00:00:00 2001 From: Adam Welc <adam.welc@oracle.com> Date: Sun, 28 Aug 2016 14:05:35 -0700 Subject: [PATCH] Rewritten parameter casts for the unique builtin. --- .../truffle/r/nodes/builtin/base/Unique.java | 42 +++++++++++++------ .../r/test/builtins/TestBuiltin_unique.java | 5 +++ 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java index 74c3cda1ad..ea894dcbba 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Unique.java @@ -22,6 +22,7 @@ */ package com.oracle.truffle.r.nodes.builtin.base; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.*; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; @@ -32,11 +33,12 @@ import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.Utils; import com.oracle.truffle.r.runtime.builtins.RBuiltin; -import com.oracle.truffle.r.runtime.data.RArgsValuesAndNames; import com.oracle.truffle.r.runtime.data.RComplex; import com.oracle.truffle.r.runtime.data.RComplexVector; import com.oracle.truffle.r.runtime.data.RDataFactory; @@ -54,9 +56,9 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector; import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractVector; -// Implements default S3 method -@RBuiltin(name = "unique", kind = INTERNAL, parameterNames = {"x", "incomparables", "fromLast", "nmax", "..."}, behavior = PURE) +@RBuiltin(name = "unique", kind = INTERNAL, parameterNames = {"x", "incomparables", "fromLast", "nmax"}, behavior = PURE) // TODO A more efficient implementation is in order; GNU R uses hash tables so perhaps we should // consider using one of the existing libraries that offer hash table implementations for primitive // types @@ -66,15 +68,31 @@ public abstract class Unique extends RBuiltinNode { private final ConditionProfile bigProfile = ConditionProfile.createBinaryProfile(); + @Override + protected void createCasts(CastBuilder casts) { + // these are similar to those in DuplicatedFunctions.java + casts.arg("x").mustBe(nullValue().or(abstractVectorValue()), RError.SHOW_CALLER, RError.Message.APPLIES_TO_VECTORS, + "unique()").mapIf(nullValue().not(), asVector()); + // not much more can be done for incomparables as it is either a vector of incomparable + // values or a (single) logical value + // TODO: coercion error must be handled by specialization as it depends on type of x (much + // like in duplicated) + casts.arg("incomparables").asVector(true); + casts.arg("fromLast").asLogicalVector().findFirst(RRuntime.LOGICAL_FALSE); + // currently not supported and not tested, but NA is a correct value (the same for empty + // vectors) whereas 0 is not (throws an error) + casts.arg("nmax").asIntegerVector().findFirst(RRuntime.INT_NA); + } + @SuppressWarnings("unused") @Specialization - protected RNull doUnique(RNull vec, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) { + protected RNull doUnique(RNull vec, RAbstractVector incomparables, byte fromLast, int nmax) { return vec; } @SuppressWarnings("unused") @Specialization - protected RStringVector doUnique(RAbstractStringVector vec, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) { + protected RStringVector doUnique(RAbstractStringVector vec, RAbstractVector incomparables, byte fromLast, int nmax) { if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) { Utils.NonRecursiveHashSet<String> set = new Utils.NonRecursiveHashSet<>(vec.getLength()); String[] data = new String[vec.getLength()]; @@ -230,7 +248,7 @@ public abstract class Unique extends RBuiltinNode { @SuppressWarnings("unused") @Specialization - protected RIntVector doUnique(RAbstractIntVector vec, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) { + protected RIntVector doUnique(RAbstractIntVector vec, RAbstractVector incomparables, byte fromLast, int nmax) { if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) { NonRecursiveHashSetInt set = new NonRecursiveHashSetInt(); int[] data = new int[16]; @@ -259,7 +277,7 @@ public abstract class Unique extends RBuiltinNode { @SuppressWarnings("unused") @Specialization(guards = "lengthOne(list)") - protected RList doUniqueL1(RList list, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) { + protected RList doUniqueL1(RList list, RAbstractVector incomparables, byte fromLast, int nmax) { return (RList) list.copyDropAttributes(); } @@ -276,7 +294,7 @@ public abstract class Unique extends RBuiltinNode { @SuppressWarnings("unused") @Specialization(guards = "!lengthOne(list)") @TruffleBoundary - protected RList doUnique(RList list, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) { + protected RList doUnique(RList list, RAbstractVector incomparables, byte fromLast, int nmax) { /* * Brute force, as manual says: Using this for lists is potentially slow, especially if the * elements are not atomic vectors (see vector) or differ only in their attributes. In the @@ -355,7 +373,7 @@ public abstract class Unique extends RBuiltinNode { @SuppressWarnings("unused") @Specialization - protected RDoubleVector doUnique(RAbstractDoubleVector vec, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) { + protected RDoubleVector doUnique(RAbstractDoubleVector vec, RAbstractVector incomparables, byte fromLast, int nmax) { if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) { Utils.NonRecursiveHashSetDouble set = new Utils.NonRecursiveHashSetDouble(vec.getLength()); double[] data = new double[vec.getLength()]; @@ -381,7 +399,7 @@ public abstract class Unique extends RBuiltinNode { @SuppressWarnings("unused") @Specialization - protected RLogicalVector doUnique(RAbstractLogicalVector vec, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) { + protected RLogicalVector doUnique(RAbstractLogicalVector vec, RAbstractVector incomparables, byte fromLast, int nmax) { ByteArray dataList = new ByteArray(vec.getLength()); for (int i = 0; i < vec.getLength(); i++) { byte val = vec.getDataAt(i); @@ -394,7 +412,7 @@ public abstract class Unique extends RBuiltinNode { @SuppressWarnings("unused") @Specialization - protected RComplexVector doUnique(RAbstractComplexVector vec, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) { + protected RComplexVector doUnique(RAbstractComplexVector vec, RAbstractVector incomparables, byte fromLast, int nmax) { if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) { Utils.NonRecursiveHashSet<RComplex> set = new Utils.NonRecursiveHashSet<>(vec.getLength()); double[] data = new double[vec.getLength() * 2]; @@ -421,7 +439,7 @@ public abstract class Unique extends RBuiltinNode { @SuppressWarnings("unused") @Specialization - protected RRawVector doUnique(RAbstractRawVector vec, byte incomparables, byte fromLast, Object nmax, RArgsValuesAndNames vararg) { + protected RRawVector doUnique(RAbstractRawVector vec, RAbstractVector incomparables, byte fromLast, int nmax) { if (bigProfile.profile(vec.getLength() * (long) vec.getLength() > BIG_THRESHOLD)) { Utils.NonRecursiveHashSet<RRaw> set = new Utils.NonRecursiveHashSet<>(vec.getLength()); byte[] data = new byte[vec.getLength()]; diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_unique.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_unique.java index ff067fd228..fb84a6827c 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_unique.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_unique.java @@ -189,5 +189,10 @@ public class TestBuiltin_unique extends TestBase { @Test public void testUnique() { assertEval("{x<-factor(c(\"a\", \"b\", \"a\")); unique(x) }"); + + assertEval("{ x<-quote(f(7, 42)); unique(x) }"); + assertEval("{ x<-function() 42; unique(x) }"); + assertEval(Ignored.Unknown, "{ unique(c(1,2,1), incomparables=function() 42) }"); + } } -- GitLab