diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AnyNA.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AnyNA.java index 65de607dba4d9a9dbec7644b0f3b59d72d49bd5b..f4ca1e913a23c8c978744c3bbdde9a521581b57c 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AnyNA.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AnyNA.java @@ -22,6 +22,7 @@ */ package com.oracle.truffle.r.nodes.builtin.base; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean; import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; @@ -30,6 +31,7 @@ import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.profiles.ValueProfile; +import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.control.RLengthNode; import com.oracle.truffle.r.runtime.RRuntime; @@ -47,12 +49,31 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.ops.na.NACheck; -@RBuiltin(name = "anyNA", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = INTERNAL_GENERIC, behavior = PURE) +@RBuiltin(name = "anyNA", kind = PRIMITIVE, parameterNames = {"x", "recursive"}, dispatch = INTERNAL_GENERIC, behavior = PURE) public abstract class AnyNA extends RBuiltinNode { private final NACheck naCheck = NACheck.create(); - public abstract byte execute(VirtualFrame frame, Object value); + private final boolean recursionAllowed; + + protected AnyNA(boolean recursionAllowed) { + this.recursionAllowed = recursionAllowed; + } + + protected AnyNA() { + this(true); + } + + protected boolean isRecursionAllowed() { + return recursionAllowed; + } + + public abstract byte execute(VirtualFrame frame, Object value, boolean recursive); + + @Override + protected void createCasts(CastBuilder casts) { + casts.arg("recursive").asLogicalVector().findFirst(RRuntime.LOGICAL_FALSE).map(toBoolean()); + } private static byte doScalar(boolean isNA) { return RRuntime.asLogical(isNA); @@ -74,84 +95,86 @@ public abstract class AnyNA extends RBuiltinNode { } @Specialization - protected byte isNA(byte value) { + protected byte isNA(byte value, @SuppressWarnings("unused") boolean recursive) { return doScalar(RRuntime.isNA(value)); } @Specialization - protected byte isNA(int value) { + protected byte isNA(int value, @SuppressWarnings("unused") boolean recursive) { return doScalar(RRuntime.isNA(value)); } @Specialization - protected byte isNA(double value) { + protected byte isNA(double value, @SuppressWarnings("unused") boolean recursive) { return doScalar(RRuntime.isNAorNaN(value)); } @Specialization - protected byte isNA(RComplex value) { + protected byte isNA(RComplex value, @SuppressWarnings("unused") boolean recursive) { return doScalar(RRuntime.isNA(value)); } @Specialization - protected byte isNA(String value) { + protected byte isNA(String value, @SuppressWarnings("unused") boolean recursive) { return doScalar(RRuntime.isNA(value)); } @Specialization - protected byte isNA(@SuppressWarnings("unused") RRaw value) { + @SuppressWarnings("unused") + protected byte isNA(RRaw value, boolean recursive) { return doScalar(false); } @Specialization - protected byte isNA(@SuppressWarnings("unused") RNull value) { + protected byte isNA(@SuppressWarnings("unused") RNull value, @SuppressWarnings("unused") boolean recursive) { return doScalar(false); } @Specialization - protected byte isNA(RAbstractIntVector vector) { + protected byte isNA(RAbstractIntVector vector, @SuppressWarnings("unused") boolean recursive) { return doVector(vector, (v, i) -> naCheck.check(v.getDataAt(i))); } @Specialization - protected byte isNA(RAbstractDoubleVector vector) { + protected byte isNA(RAbstractDoubleVector vector, @SuppressWarnings("unused") boolean recursive) { // since return doVector(vector, (v, i) -> naCheck.checkNAorNaN(v.getDataAt(i))); } @Specialization - protected byte isNA(RAbstractComplexVector vector) { + protected byte isNA(RAbstractComplexVector vector, @SuppressWarnings("unused") boolean recursive) { return doVector(vector, (v, i) -> naCheck.check(v.getDataAt(i))); } @Specialization - protected byte isNA(RAbstractStringVector vector) { + protected byte isNA(RAbstractStringVector vector, @SuppressWarnings("unused") boolean recursive) { return doVector(vector, (v, i) -> naCheck.check(v.getDataAt(i))); } @Specialization - protected byte isNA(RAbstractLogicalVector vector) { + protected byte isNA(RAbstractLogicalVector vector, @SuppressWarnings("unused") boolean recursive) { return doVector(vector, (v, i) -> naCheck.check(v.getDataAt(i))); } @Specialization - protected byte isNA(@SuppressWarnings("unused") RAbstractRawVector vector) { + protected byte isNA(@SuppressWarnings("unused") RAbstractRawVector vector, @SuppressWarnings("unused") boolean recursive) { return doScalar(false); } - protected AnyNA createRecursive() { - return AnyNANodeGen.create(null); + protected AnyNA createRecursive(boolean recursive) { + return AnyNANodeGen.create(recursive, null); } - @Specialization - protected byte isNA(VirtualFrame frame, RList list, // - @Cached("createRecursive()") AnyNA recursive, // - @Cached("createClassProfile()") ValueProfile elementProfile, // + @Specialization(guards = "isRecursionAllowed()") + protected byte isNA(VirtualFrame frame, RList list, boolean recursive, + @Cached("createRecursive(recursive)") AnyNA recursiveNode, + @Cached("createClassProfile()") ValueProfile elementProfile, @Cached("create()") RLengthNode length) { + for (int i = 0; i < list.getLength(); i++) { Object value = elementProfile.profile(list.getDataAt(i)); - if (length.executeInteger(frame, value) == 1) { - byte result = recursive.execute(frame, value); + if (length.executeInteger(frame, value) > 0) { + byte result = recursiveNode.execute(frame, value, recursive); if (result == RRuntime.LOGICAL_TRUE) { return RRuntime.LOGICAL_TRUE; } @@ -159,4 +182,10 @@ public abstract class AnyNA extends RBuiltinNode { } return RRuntime.LOGICAL_FALSE; } + + @Specialization(guards = "!isRecursionAllowed()") + @SuppressWarnings("unused") + protected byte isNA(VirtualFrame frame, RList list, boolean recursive) { + return RRuntime.LOGICAL_FALSE; + } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Arg.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Arg.java index 618c7c4b5f99ed26bab4713d06855088ad30c6ab..1e7928692829dd0675fe355bfc254895ce83495e 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Arg.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Arg.java @@ -22,22 +22,138 @@ */ package com.oracle.truffle.r.nodes.builtin.base; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.complexValue; import static com.oracle.truffle.r.runtime.RDispatch.COMPLEX_GROUP_GENERIC; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; -import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; -import com.oracle.truffle.r.runtime.RInternalError; +import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RBuiltin; +import com.oracle.truffle.r.runtime.data.RComplex; +import com.oracle.truffle.r.runtime.data.RDataFactory; +import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; +import com.oracle.truffle.r.runtime.ops.na.NACheck; @RBuiltin(name = "Arg", kind = PRIMITIVE, parameterNames = {"z"}, dispatch = COMPLEX_GROUP_GENERIC, behavior = PURE) public abstract class Arg extends RBuiltinNode { + private final ConditionProfile signumProfile = ConditionProfile.createBinaryProfile(); + private final NACheck naCheck = NACheck.create(); + + @Override + protected void createCasts(CastBuilder casts) { + casts.arg("z").mustBe(numericValue().or(complexValue()), RError.Message.NON_NUMERIC_ARGUMENT_FUNCTION); + } + + @Specialization + protected double arg(double z) { + naCheck.enable(z); + if (naCheck.check(z)) { + return RRuntime.DOUBLE_NA; + } + if (signumProfile.profile(z >= 0)) { + return 0; + } else { + return Math.PI; + } + } + @Specialization - @TruffleBoundary - protected Object im(@SuppressWarnings("unused") Object value) { - throw RInternalError.unimplemented(); + protected RAbstractDoubleVector arg(RAbstractDoubleVector v) { + double[] result = new double[v.getLength()]; + + naCheck.enable(v); + for (int i = 0; i < v.getLength(); i++) { + double z = v.getDataAt(i); + if (naCheck.check(z)) { + result[i] = RRuntime.DOUBLE_NA; + } else { + result[i] = z >= 0 ? 0 : Math.PI; + } + } + + return RDataFactory.createDoubleVector(result, v.isComplete()); + } + + @Specialization + protected double arg(int z) { + naCheck.enable(z); + if (naCheck.check(z)) { + return RRuntime.DOUBLE_NA; + } + if (signumProfile.profile(z >= 0)) { + return 0; + } else { + return Math.PI; + } + } + + @Specialization + protected RAbstractDoubleVector arg(RAbstractIntVector v) { + double[] result = new double[v.getLength()]; + + naCheck.enable(v); + for (int i = 0; i < v.getLength(); i++) { + int z = v.getDataAt(i); + if (naCheck.check(z)) { + result[i] = RRuntime.DOUBLE_NA; + } else { + result[i] = z >= 0 ? 0 : Math.PI; + } + } + + return RDataFactory.createDoubleVector(result, v.isComplete()); + } + + @Specialization + protected double arg(byte z) { + naCheck.enable(z); + if (naCheck.check(z)) { + return RRuntime.DOUBLE_NA; + } + return 0; + } + + @Specialization + protected RAbstractDoubleVector arg(RAbstractLogicalVector v) { + double[] result = new double[v.getLength()]; + + naCheck.enable(v); + for (int i = 0; i < v.getLength(); i++) { + int z = v.getDataAt(i); + if (naCheck.check(z)) { + result[i] = RRuntime.DOUBLE_NA; + } else { + result[i] = 0; + } + } + + return RDataFactory.createDoubleVector(result, v.isComplete()); + } + + @Specialization + protected RAbstractDoubleVector arg(RAbstractComplexVector v) { + double[] result = new double[v.getLength()]; + + naCheck.enable(v); + for (int i = 0; i < v.getLength(); i++) { + RComplex z = v.getDataAt(i); + if (naCheck.check(z)) { + result[i] = RRuntime.DOUBLE_NA; + } else { + result[i] = Math.atan2(z.getImaginaryPart(), z.getRealPart()); + } + } + + return RDataFactory.createDoubleVector(result, v.isComplete()); } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index bff55d8caa2d1860e270b43fcccca7e49b816efd..02ae51b88c51d7a74fa1cc7fee9daf0d4c87f480 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -2620,6 +2620,22 @@ In anyDuplicated.default(c(1L, 2L, 1L, 1L, 3L, 2L), incomparables = "cat") : #argv <- list(c(1.81566026854212e-304, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0));do.call('anyNA', argv) [1] FALSE +##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA2 +#anyNA(list(list(4,5,NA), 3), recursive=TRUE) +[1] TRUE + +##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA2 +#anyNA(list(list(c(NA)),c(1)), recursive=FALSE) +[1] FALSE + +##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA2 +#anyNA(list(list(c(NA)),c(1)), recursive=TRUE) +[1] TRUE + +##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA2 +#anyNA(list(list(c(NA)),c(1)), recursive=c(FALSE,TRUE)) +[1] FALSE + ##com.oracle.truffle.r.test.builtins.TestBuiltin_aperm.testAperm #{ a = array(1:24,c(2,3,4)); b = aperm(a); c(dim(b)[1],dim(b)[2],dim(b)[3]) } [1] 4 3 2 diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_anyNA.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_anyNA.java index 517851cc383fcab35fd0857ef817c1c24e50c9bf..cea95545ce24d229ee5943fa53220345c2d9141b 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_anyNA.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_anyNA.java @@ -22,4 +22,13 @@ public class TestBuiltin_anyNA extends TestBase { public void testanyNA1() { assertEval("argv <- list(c(1.81566026854212e-304, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0));do.call('anyNA', argv)"); } + + @Test + public void testanyNA2() { + assertEval("anyNA(list(list(c(NA)),c(1)), recursive=TRUE)"); + assertEval("anyNA(list(list(c(NA)),c(1)), recursive=FALSE)"); + assertEval("anyNA(list(list(c(NA)),c(1)), recursive=c(FALSE,TRUE))"); + assertEval("anyNA(list(list(4,5,NA), 3), recursive=TRUE)"); + } + }