From f96a88338a86150c3366fb3a01c307d95af45914 Mon Sep 17 00:00:00 2001
From: Lukas Stadler <lukas.stadler@oracle.com>
Date: Thu, 16 Nov 2017 13:25:59 +0100
Subject: [PATCH] refactor anyNA to use VectorAccess, more correct
 implementation

---
 .../truffle/r/nodes/builtin/base/AnyNA.java   | 162 ++++++++++--------
 .../truffle/r/test/ExpectedTestOutput.test    |  47 +++++
 .../r/test/builtins/TestBuiltin_anyNA.java    |  11 ++
 3 files changed, 149 insertions(+), 71 deletions(-)

diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AnyNA.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AnyNA.java
index 515d57eec9..8ee5500175 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AnyNA.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AnyNA.java
@@ -27,6 +27,7 @@ import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC;
 import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE_SUMMARY;
 import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
 
+import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
 import com.oracle.truffle.api.dsl.Cached;
 import com.oracle.truffle.api.dsl.Specialization;
 import com.oracle.truffle.api.profiles.ValueProfile;
@@ -39,19 +40,23 @@ import com.oracle.truffle.r.runtime.data.RList;
 import com.oracle.truffle.r.runtime.data.RMissing;
 import com.oracle.truffle.r.runtime.data.RNull;
 import com.oracle.truffle.r.runtime.data.RRaw;
-import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector;
-import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
-import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
-import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
-import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector;
-import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
-import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
-import com.oracle.truffle.r.runtime.ops.na.NACheck;
+import com.oracle.truffle.r.runtime.data.model.RAbstractAtomicVector;
+import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
+import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
 
 @RBuiltin(name = "anyNA", kind = PRIMITIVE, parameterNames = {"x", "recursive"}, dispatch = INTERNAL_GENERIC, behavior = PURE_SUMMARY)
 public abstract class AnyNA extends RBuiltinNode.Arg2 {
 
-    private final NACheck naCheck = NACheck.create();
+    // true only for the child node built by createRecursive() (run on list elements); false otherwise
+    protected final boolean isRecursive;
+
+    protected AnyNA() {
+        this.isRecursive = false;
+    }
+
+    protected AnyNA(boolean isRecursive) {
+        this.isRecursive = isRecursive;
+    }
 
     public abstract byte execute(Object value, boolean recursive);
 
@@ -65,117 +70,132 @@ public abstract class AnyNA extends RBuiltinNode.Arg2 {
         return new Object[]{RMissing.instance, RRuntime.LOGICAL_FALSE};
     }
 
-    private static byte doScalar(boolean isNA) {
-        return RRuntime.asLogical(isNA);
-    }
-
-    @FunctionalInterface
-    private interface VectorIndexPredicate<T extends RAbstractVector> {
-        boolean apply(T vector, int index);
-    }
-
-    private <T extends RAbstractVector> byte doVector(T vector, VectorIndexPredicate<T> predicate) {
-        naCheck.enable(vector);
-        for (int i = 0; i < vector.getLength(); i++) {
-            if (predicate.apply(vector, i)) {
-                return RRuntime.LOGICAL_TRUE;
-            }
-        }
-        return RRuntime.LOGICAL_FALSE;
-    }
-
     @Specialization
     protected byte isNA(byte value, @SuppressWarnings("unused") boolean recursive) {
-        return doScalar(RRuntime.isNA(value));
+        return RRuntime.asLogical(RRuntime.isNA(value));
     }
 
     @Specialization
     protected byte isNA(int value, @SuppressWarnings("unused") boolean recursive) {
-        return doScalar(RRuntime.isNA(value));
+        return RRuntime.asLogical(RRuntime.isNA(value));
     }
 
     @Specialization
     protected byte isNA(double value, @SuppressWarnings("unused") boolean recursive) {
-        return doScalar(RRuntime.isNAorNaN(value));
+        return RRuntime.asLogical(RRuntime.isNAorNaN(value));
     }
 
     @Specialization
     protected byte isNA(RComplex value, @SuppressWarnings("unused") boolean recursive) {
-        return doScalar(RRuntime.isNA(value));
+        return RRuntime.asLogical(RRuntime.isNA(value));
     }
 
     @Specialization
     protected byte isNA(String value, @SuppressWarnings("unused") boolean recursive) {
-        return doScalar(RRuntime.isNA(value));
+        return RRuntime.asLogical(RRuntime.isNA(value));
     }
 
     @Specialization
     @SuppressWarnings("unused")
     protected byte isNA(RRaw value, boolean recursive) {
-        return doScalar(false);
+        return RRuntime.LOGICAL_FALSE;
     }
 
     @Specialization
     protected byte isNA(@SuppressWarnings("unused") RNull value, @SuppressWarnings("unused") boolean recursive) {
-        return doScalar(false);
-    }
-
-    @Specialization
-    protected byte isNA(RAbstractIntVector vector, @SuppressWarnings("unused") boolean recursive) {
-        return doVector(vector, (v, i) -> naCheck.check(v.getDataAt(i)));
-    }
-
-    @Specialization
-    protected byte isNA(RAbstractDoubleVector vector, @SuppressWarnings("unused") boolean recursive) {
-        // since
-        return doVector(vector, (v, i) -> naCheck.checkNAorNaN(v.getDataAt(i)));
+        return RRuntime.LOGICAL_FALSE;
     }
 
-    @Specialization
-    protected byte isNA(RAbstractComplexVector vector, @SuppressWarnings("unused") boolean recursive) {
-        return doVector(vector, (v, i) -> naCheck.check(v.getDataAt(i)));
+    @Specialization(guards = "xAccess.supports(x)")
+    protected byte anyNACached(RAbstractAtomicVector x, @SuppressWarnings("unused") boolean recursive,
+                    @Cached("x.access()") VectorAccess xAccess) {
+        switch (xAccess.getType()) {
+            case Logical:
+            case Integer:
+            case Character:
+                // only scan when the vector is not flagged complete (complete means no NAs)
+                if (!x.isComplete()) {
+                    try (SequentialIterator iter = xAccess.access(x)) {
+                        while (xAccess.next(iter)) {
+                            if (xAccess.isNA(iter)) {
+                                return RRuntime.LOGICAL_TRUE;
+                            }
+                        }
+                    }
+                }
+                break;
+            case Raw:
+                return RRuntime.LOGICAL_FALSE;
+            case Double:
+                // NaN counts as well, so every element is checked (no isComplete shortcut here)
+                try (SequentialIterator iter = xAccess.access(x)) {
+                    while (xAccess.next(iter)) {
+                        if (xAccess.na.checkNAorNaN(xAccess.getDouble(iter))) {
+                            return RRuntime.LOGICAL_TRUE;
+                        }
+                    }
+                }
+                break;
+            case Complex:
+                // NA or NaN in either the real or the imaginary part makes the element count
+                try (SequentialIterator iter = xAccess.access(x)) {
+                    while (xAccess.next(iter)) {
+                        if (xAccess.na.checkNAorNaN(xAccess.getComplexR(iter)) || xAccess.na.checkNAorNaN(xAccess.getComplexI(iter))) {
+                            return RRuntime.LOGICAL_TRUE;
+                        }
+                    }
+                }
+                break;
+        }
+        return RRuntime.LOGICAL_FALSE;
     }
 
-    @Specialization
-    protected byte isNA(RAbstractStringVector vector, @SuppressWarnings("unused") boolean recursive) {
-        return doVector(vector, (v, i) -> naCheck.check(v.getDataAt(i)));
+    @Specialization(replaces = "anyNACached")
+    protected byte anyNAGeneric(RAbstractAtomicVector x, boolean recursive) {
+        return anyNACached(x, recursive, x.slowPathAccess());
     }
 
-    @Specialization
-    protected byte isNA(RAbstractLogicalVector vector, @SuppressWarnings("unused") boolean recursive) {
-        return doVector(vector, (v, i) -> naCheck.check(v.getDataAt(i)));
+    protected AnyNA createRecursive() {
+        return AnyNANodeGen.create(true);
     }
 
-    @Specialization
-    protected byte isNA(@SuppressWarnings("unused") RAbstractRawVector vector, @SuppressWarnings("unused") boolean recursive) {
-        return doScalar(false);
+    @Specialization(guards = {"isRecursive", "recursive == cachedRecursive"})
+    protected byte isNARecursive(RList list, boolean recursive,
+                    @Cached("recursive") boolean cachedRecursive,
+                    @Cached("createClassProfile()") ValueProfile elementProfile,
+                    @Cached("create()") RLengthNode length) {
+        if (cachedRecursive) {
+            for (int i = 0; i < list.getLength(); i++) {
+                Object value = elementProfile.profile(list.getDataAt(i));
+                if (length.executeInteger(value) > 0) {
+                    if (recursive(recursive, value) == RRuntime.LOGICAL_TRUE) {
+                        return RRuntime.LOGICAL_TRUE;
+                    }
+                }
+            }
+        }
+        return RRuntime.LOGICAL_FALSE;
     }
 
-    protected AnyNA createRecursive() {
-        return AnyNANodeGen.create();
+    @TruffleBoundary
+    private byte recursive(boolean recursive, Object value) {
+        return execute(value, recursive);
     }
 
-    @Specialization(guards = "recursive")
+    @Specialization(guards = {"!isRecursive", "recursive == cachedRecursive"})
     protected byte isNA(RList list, boolean recursive,
+                    @Cached("recursive") boolean cachedRecursive,
                     @Cached("createRecursive()") AnyNA recursiveNode,
                     @Cached("createClassProfile()") ValueProfile elementProfile,
                     @Cached("create()") RLengthNode length) {
-
         for (int i = 0; i < list.getLength(); i++) {
             Object value = elementProfile.profile(list.getDataAt(i));
-            if (length.executeInteger(value) > 0) {
-                byte result = recursiveNode.execute(value, recursive);
-                if (result == RRuntime.LOGICAL_TRUE) {
+            if (cachedRecursive || length.executeInteger(value) == 1) {
+                if (recursiveNode.execute(value, recursive) == RRuntime.LOGICAL_TRUE) {
                     return RRuntime.LOGICAL_TRUE;
                 }
             }
         }
         return RRuntime.LOGICAL_FALSE;
     }
-
-    @Specialization(guards = "!recursive")
-    @SuppressWarnings("unused")
-    protected byte isNA(RList list, boolean recursive) {
-        return RRuntime.LOGICAL_FALSE;
-    }
 }
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
index 8aef4723e6..a2c4cbdf10 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
@@ -3221,6 +3221,38 @@ In anyDuplicated.default(c(1L, 2L, 1L, 1L, 3L, 2L), incomparables = "cat") :
 #anyNA(c(1, NA, 3), recursive = TRUE)
 [1] TRUE
 
+##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3#
+#anyNA(list(1, NA))
+[1] TRUE
+
+##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3#
+#anyNA(list(1, NA), recursive = TRUE)
+[1] TRUE
+
+##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3#
+#anyNA(list(a = NA))
+[1] TRUE
+
+##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3#
+#anyNA(list(a = NA), recursive = TRUE)
+[1] TRUE
+
+##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3#
+#anyNA(list(a = NA, b = 'a'))
+[1] TRUE
+
+##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3#
+#anyNA(list(a = NA, b = 'a'), recursive = TRUE)
+[1] TRUE
+
+##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3#
+#anyNA(list(a = c('asdf', NA), b = 'a'))
+[1] FALSE
+
+##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3#
+#anyNA(list(a = c('asdf', NA), b = 'a'), recursive = TRUE)
+[1] TRUE
+
 ##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3#
 #anyNA(list(a = c(1, 2, 3), b = 'a'))
 [1] FALSE
@@ -3229,6 +3261,10 @@ In anyDuplicated.default(c(1L, 2L, 1L, 1L, 3L, 2L), incomparables = "cat") :
 #anyNA(list(a = c(1, 2, 3), b = 'a'), recursive = TRUE)
 [1] FALSE
 
+##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3#
+#anyNA(list(a = c(1, 2, 3), b = list(NA, 'a')), recursive = TRUE)
+[1] TRUE
+
 ##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3#
 #anyNA(list(a = c(1, NA, 3), b = 'a'))
 [1] FALSE
@@ -3237,6 +3273,14 @@ In anyDuplicated.default(c(1L, 2L, 1L, 1L, 3L, 2L), incomparables = "cat") :
 #anyNA(list(a = c(1, NA, 3), b = 'a'), recursive = TRUE)
 [1] TRUE
 
+##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3#
+#anyNA(list(a = c(NA, 3), b = 'a'))
+[1] FALSE
+
+##com.oracle.truffle.r.test.builtins.TestBuiltin_anyNA.testanyNA3#
+#anyNA(list(a = c(NA, 3), b = 'a'), recursive = TRUE)
+[1] TRUE
+
 ##com.oracle.truffle.r.test.builtins.TestBuiltin_aperm.testAperm#
 #{ a = array(1:24,c(2,3,4)); b = aperm(a); c(dim(b)[1],dim(b)[2],dim(b)[3]) }
 [1] 4 3 2
@@ -9955,6 +9999,9 @@ Error in attributes(x) <- 44 : attributes must be a list or NULL
 ##com.oracle.truffle.r.test.builtins.TestBuiltin_attributesassign.testArgsCasts#
 #x <- 42;  attributes(x) <- NULL
 
+##com.oracle.truffle.r.test.builtins.TestBuiltin_attributesassign.testArgsCasts#
+#x <- 42;  attributes(x) <- list()
+
 ##com.oracle.truffle.r.test.builtins.TestBuiltin_attributesassign.testattributesassign1#Ignored.ImplementationError#
 #argv <- list(NULL, NULL);`attributes<-`(argv[[1]],argv[[2]]);
 NULL
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_anyNA.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_anyNA.java
index faeee96e06..a428859a15 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_anyNA.java
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_anyNA.java
@@ -38,7 +38,18 @@ public class TestBuiltin_anyNA extends TestBase {
         assertEval("anyNA(c(1, NA, 3), recursive = TRUE)");
         assertEval("anyNA(list(a = c(1, 2, 3), b = 'a'))");
         assertEval("anyNA(list(a = c(1, NA, 3), b = 'a'))");
+        assertEval("anyNA(list(a = c('asdf', NA), b = 'a'))");
+        assertEval("anyNA(list(a = c(NA, 3), b = 'a'))");
+        assertEval("anyNA(list(a = NA, b = 'a'))");
+        assertEval("anyNA(list(a = NA))");
+        assertEval("anyNA(list(1, NA))");
+        assertEval("anyNA(list(a = c('asdf', NA), b = 'a'), recursive = TRUE)");
+        assertEval("anyNA(list(a = c(NA, 3), b = 'a'), recursive = TRUE)");
+        assertEval("anyNA(list(a = NA, b = 'a'), recursive = TRUE)");
+        assertEval("anyNA(list(a = NA), recursive = TRUE)");
+        assertEval("anyNA(list(1, NA), recursive = TRUE)");
         assertEval("anyNA(list(a = c(1, 2, 3), b = 'a'), recursive = TRUE)");
         assertEval("anyNA(list(a = c(1, NA, 3), b = 'a'), recursive = TRUE)");
+        assertEval("anyNA(list(a = c(1, 2, 3), b = list(NA, 'a')), recursive = TRUE)");
     }
 }
-- 
GitLab