From f96a88338a86150c3366fb3a01c307d95af45914 Mon Sep 17 00:00:00 2001 From: Lukas Stadler <lukas.stadler@oracle.com> Date: Thu, 16 Nov 2017 13:25:59 +0100 Subject: [PATCH] refactor anyNA to use VectorAccess, more correct implementation --- .../truffle/r/nodes/builtin/base/AnyNA.java | 162 ++++++++++-------- .../truffle/r/test/ExpectedTestOutput.test | 47 +++++ .../r/test/builtins/TestBuiltin_anyNA.java | 11 ++ 3 files changed, 149 insertions(+), 71 deletions(-) diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AnyNA.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AnyNA.java index 515d57eec9..8ee5500175 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AnyNA.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AnyNA.java @@ -27,6 +27,7 @@ import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE_SUMMARY; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; +import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.ValueProfile; @@ -39,19 +40,23 @@ import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RMissing; import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RRaw; -import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractVector; -import com.oracle.truffle.r.runtime.ops.na.NACheck; +import com.oracle.truffle.r.runtime.data.model.RAbstractAtomicVector; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; @RBuiltin(name = "anyNA", kind = PRIMITIVE, parameterNames = {"x", "recursive"}, dispatch = INTERNAL_GENERIC, behavior = PURE_SUMMARY) public abstract class AnyNA extends RBuiltinNode.Arg2 { - private final NACheck naCheck = NACheck.create(); + // true if this is the first recursive level + protected final boolean isRecursive; + + protected AnyNA() { + this.isRecursive = false; + } + + protected AnyNA(boolean isRecursive) { + this.isRecursive = isRecursive; + } public abstract byte execute(Object value, boolean recursive); @@ -65,117 +70,132 @@ public abstract class AnyNA extends RBuiltinNode.Arg2 { return new Object[]{RMissing.instance, RRuntime.LOGICAL_FALSE}; } - private static byte doScalar(boolean isNA) { - return RRuntime.asLogical(isNA); - } - - @FunctionalInterface - private interface VectorIndexPredicate<T extends RAbstractVector> { - boolean apply(T vector, int index); - } - - private <T extends RAbstractVector> byte doVector(T vector, VectorIndexPredicate<T> predicate) { - naCheck.enable(vector); - for (int i = 0; i < vector.getLength(); i++) { - if (predicate.apply(vector, i)) { - return RRuntime.LOGICAL_TRUE; - } - } - return RRuntime.LOGICAL_FALSE; - } - @Specialization protected byte isNA(byte value, @SuppressWarnings("unused") boolean recursive) { - return doScalar(RRuntime.isNA(value)); + return RRuntime.asLogical(RRuntime.isNA(value)); } @Specialization protected byte isNA(int value, @SuppressWarnings("unused") boolean recursive) { - return doScalar(RRuntime.isNA(value)); + return RRuntime.asLogical(RRuntime.isNA(value)); } @Specialization protected byte isNA(double value, @SuppressWarnings("unused") boolean recursive) { - return doScalar(RRuntime.isNAorNaN(value)); + return RRuntime.asLogical(RRuntime.isNAorNaN(value)); } @Specialization protected byte isNA(RComplex value, @SuppressWarnings("unused") boolean recursive) { - return doScalar(RRuntime.isNA(value)); + return RRuntime.asLogical(RRuntime.isNA(value)); } @Specialization protected byte isNA(String value, @SuppressWarnings("unused") boolean recursive) { - return doScalar(RRuntime.isNA(value)); + return RRuntime.asLogical(RRuntime.isNA(value)); } @Specialization @SuppressWarnings("unused") protected byte isNA(RRaw value, boolean recursive) { - return doScalar(false); + return RRuntime.LOGICAL_FALSE; } @Specialization protected byte isNA(@SuppressWarnings("unused") RNull value, @SuppressWarnings("unused") boolean recursive) { - return doScalar(false); - } - - @Specialization - protected byte isNA(RAbstractIntVector vector, @SuppressWarnings("unused") boolean recursive) { - return doVector(vector, (v, i) -> naCheck.check(v.getDataAt(i))); - } - - @Specialization - protected byte isNA(RAbstractDoubleVector vector, @SuppressWarnings("unused") boolean recursive) { - // since - return doVector(vector, (v, i) -> naCheck.checkNAorNaN(v.getDataAt(i))); + return RRuntime.LOGICAL_FALSE; } - @Specialization - protected byte isNA(RAbstractComplexVector vector, @SuppressWarnings("unused") boolean recursive) { - return doVector(vector, (v, i) -> naCheck.check(v.getDataAt(i))); + @Specialization(guards = "xAccess.supports(x)") + protected byte anyNACached(RAbstractAtomicVector x, @SuppressWarnings("unused") boolean recursive, + @Cached("x.access()") VectorAccess xAccess) { + switch (xAccess.getType()) { + case Logical: + case Integer: + case Character: + // shortcut when we know there's no NAs + if (!x.isComplete()) { + try (SequentialIterator iter = xAccess.access(x)) { + while (xAccess.next(iter)) { + if (xAccess.isNA(iter)) { + return RRuntime.LOGICAL_TRUE; + } + } + } + } + break; + case Raw: + return RRuntime.LOGICAL_FALSE; + case Double: + // we need to check for NaNs + try (SequentialIterator iter = xAccess.access(x)) { + while (xAccess.next(iter)) { + if (xAccess.na.checkNAorNaN(xAccess.getDouble(iter))) { + return RRuntime.LOGICAL_TRUE; + } + } + } + break; + case Complex: + // we need to check for NaNs + try (SequentialIterator iter = xAccess.access(x)) { + while (xAccess.next(iter)) { + if (xAccess.na.checkNAorNaN(xAccess.getComplexR(iter)) || xAccess.na.checkNAorNaN(xAccess.getComplexR(iter))) { + return RRuntime.LOGICAL_TRUE; + } + } + } + break; + } + return RRuntime.LOGICAL_FALSE; } - @Specialization - protected byte isNA(RAbstractStringVector vector, @SuppressWarnings("unused") boolean recursive) { - return doVector(vector, (v, i) -> naCheck.check(v.getDataAt(i))); + @Specialization(replaces = "anyNACached") + protected byte anyNAGeneric(RAbstractAtomicVector x, boolean recursive) { + return anyNACached(x, recursive, x.slowPathAccess()); } - @Specialization - protected byte isNA(RAbstractLogicalVector vector, @SuppressWarnings("unused") boolean recursive) { - return doVector(vector, (v, i) -> naCheck.check(v.getDataAt(i))); + protected AnyNA createRecursive() { + return AnyNANodeGen.create(true); } - @Specialization - protected byte isNA(@SuppressWarnings("unused") RAbstractRawVector vector, @SuppressWarnings("unused") boolean recursive) { - return doScalar(false); + @Specialization(guards = {"isRecursive", "recursive == cachedRecursive"}) + protected byte isNARecursive(RList list, boolean recursive, + @Cached("recursive") boolean cachedRecursive, + @Cached("createClassProfile()") ValueProfile elementProfile, + @Cached("create()") RLengthNode length) { + if (cachedRecursive) { + for (int i = 0; i < list.getLength(); i++) { + Object value = elementProfile.profile(list.getDataAt(i)); + if (length.executeInteger(value) > 0) { + if (recursive(recursive, value) == RRuntime.LOGICAL_TRUE) { + return RRuntime.LOGICAL_TRUE; + } + } + } + } + return RRuntime.LOGICAL_FALSE; } - protected AnyNA createRecursive() { - return AnyNANodeGen.create(); + @TruffleBoundary + private byte recursive(boolean recursive, Object value) { + return execute(value, recursive); } - @Specialization(guards = "recursive") + @Specialization(guards = {"!isRecursive", "recursive == cachedRecursive"}) protected byte isNA(RList list, boolean recursive, + @Cached("recursive") boolean cachedRecursive, @Cached("createRecursive()") AnyNA recursiveNode, @Cached("createClassProfile()") ValueProfile elementProfile, @Cached("create()") RLengthNode length) { - for (int i = 0; i < list.getLength(); i++) { Object value = elementProfile.profile(list.getDataAt(i)); - if (length.executeInteger(value) > 0) { - byte result = recursiveNode.execute(value, recursive); - if (result == RRuntime.LOGICAL_TRUE) { + if (cachedRecursive || length.executeInteger(value) == 1) { + if (recursiveNode.execute(value, recursive) == RRuntime.LOGICAL_TRUE) { return RRuntime.LOGICAL_TRUE; } } } return RRuntime.LOGICAL_FALSE; } - - @Specialization(guards = "!recursive") - @SuppressWarnings("unused") - protected byte isNA(RList list, boolean recursive) { - return RRuntime.LOGICAL_FALSE; - } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index 8aef4723e6..a2c4cbdf10 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -3221,6 +3221,38 @@ In anyDuplicated.default(c(1L, 2L, 1L, 1L, 3L, 2L), incomparables = "cat") : #anyNA(c(1, NA, 3), recursive = TRUE) [1] TRUE +##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3# +#anyNA(list(1, NA)) +[1] TRUE + +##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3# +#anyNA(list(1, NA), recursive = TRUE) +[1] TRUE + +##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3# +#anyNA(list(a = NA)) +[1] TRUE + +##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3# +#anyNA(list(a = NA), recursive = TRUE) +[1] TRUE + +##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3# +#anyNA(list(a = NA, b = 'a')) +[1] TRUE + +##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3# +#anyNA(list(a = NA, b = 'a'), recursive = TRUE) +[1] TRUE + +##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3# +#anyNA(list(a = c('asdf', NA), b = 'a')) +[1] FALSE + +##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3# +#anyNA(list(a = c('asdf', NA), b = 'a'), recursive = TRUE) +[1] TRUE + ##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3# #anyNA(list(a = c(1, 2, 3), b = 'a')) [1] FALSE @@ -3229,6 +3261,10 @@ In anyDuplicated.default(c(1L, 2L, 1L, 1L, 3L, 2L), incomparables = "cat") : #anyNA(list(a = c(1, 2, 3), b = 'a'), recursive = TRUE) [1] FALSE +##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3# +#anyNA(list(a = c(1, 2, 3), b = list(NA, 'a')), recursive = TRUE) +[1] TRUE + ##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3# #anyNA(list(a = c(1, NA, 3), b = 'a')) [1] FALSE @@ -3237,6 +3273,14 @@ In anyDuplicated.default(c(1L, 2L, 1L, 1L, 3L, 2L), incomparables = "cat") : #anyNA(list(a = c(1, NA, 3), b = 'a'), recursive = TRUE) [1] TRUE +##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3# +#anyNA(list(a = c(NA, 3), b = 'a')) +[1] FALSE + +##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3# +#anyNA(list(a = c(NA, 3), b = 'a'), recursive = TRUE) +[1] TRUE + ##com.oracle.truffle.r.test.builtins.TestBuiltin_aperm.testAperm# #{ a = array(1:24,c(2,3,4)); b = aperm(a); c(dim(b)[1],dim(b)[2],dim(b)[3]) } [1] 4 3 2 @@ -9955,6 +9999,9 @@ Error in attributes(x) <- 44 : attributes must be a list or NULL ##com.oracle.truffle.r.test.builtins.TestBuiltin_attributesassign.testArgsCasts# #x <- 42; attributes(x) <- NULL +##com.oracle.truffle.r.test.builtins.TestBuiltin_attributesassign.testArgsCasts# +#x <- 42; attributes(x) <- list() + ##com.oracle.truffle.r.test.builtins.TestBuiltin_attributesassign.testattributesassign1#Ignored.ImplementationError# #argv <- list(NULL, NULL);`attributes<-`(argv[[1]],argv[[2]]); NULL diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_anyNA.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_anyNA.java index faeee96e06..a428859a15 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_anyNA.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_anyNA.java @@ -38,7 +38,18 @@ public class TestBuiltin_anyNA extends TestBase { assertEval("anyNA(c(1, NA, 3), recursive = TRUE)"); assertEval("anyNA(list(a = c(1, 2, 3), b = 'a'))"); assertEval("anyNA(list(a = c(1, NA, 3), b = 'a'))"); + assertEval("anyNA(list(a = c('asdf', NA), b = 'a'))"); + assertEval("anyNA(list(a = c(NA, 3), b = 'a'))"); + assertEval("anyNA(list(a = NA, b = 'a'))"); + assertEval("anyNA(list(a = NA))"); + assertEval("anyNA(list(1, NA))"); + assertEval("anyNA(list(a = c('asdf', NA), b = 'a'), recursive = TRUE)"); + assertEval("anyNA(list(a = c(NA, 3), b = 'a'), recursive = TRUE)"); + assertEval("anyNA(list(a = NA, b = 'a'), recursive = TRUE)"); + assertEval("anyNA(list(a = NA), recursive = TRUE)"); + assertEval("anyNA(list(1, NA), recursive = TRUE)"); assertEval("anyNA(list(a = c(1, 2, 3), b = 'a'), recursive = TRUE)"); assertEval("anyNA(list(a = c(1, NA, 3), b = 'a'), recursive = TRUE)"); + assertEval("anyNA(list(a = c(1, 2, 3), b = list(NA, 'a')), recursive = TRUE)"); } } -- GitLab