diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/CastBuilderTest.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/CastBuilderTest.java index e5261bd109bb11571b37481454a1b7fe477a7aaf..9ae60dcab66d4349a48bc580a2a8359125dbe5d2 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/CastBuilderTest.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/CastBuilderTest.java @@ -90,6 +90,7 @@ import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.builtins.RBuiltinKind; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDoubleVector; +import com.oracle.truffle.r.runtime.data.RIntSequence; import com.oracle.truffle.r.runtime.data.RIntVector; import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RMissing; @@ -587,6 +588,22 @@ public class CastBuilderTest { testPipeline(); } + @Test + public void testSampleNonNASequence() { + arg.notNA(RError.Message.GENERIC, "Error"); + RIntSequence seq = RDataFactory.createIntSequence(1, 1, 1); + Object res = cast(seq); + Assert.assertSame(seq, res); + } + + @Test + public void testSampleNAVector() { + arg.notNA("REPLACEMENT"); + RDoubleVector vec = RDataFactory.createDoubleVector(new double[]{0, 1, RRuntime.DOUBLE_NA, 3}, false); + Object res = cast(vec); + Assert.assertEquals("REPLACEMENT", res); + } + @Test public void testPreserveNonVectorFlag() { arg.allowNull().asVector(true); diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/NonNANode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/NonNANode.java index 65a04a2b0b631c99411f704aa182b607485e0da1..cfeed6b5eae610a7cc77fdad07ae01d6b2bccd6e 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/NonNANode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/NonNANode.java @@ -22,13 +22,20 @@ */ package com.oracle.truffle.r.nodes.unary; +import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.data.RComplex; import com.oracle.truffle.r.runtime.data.RNull; +import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; import com.oracle.truffle.r.runtime.data.model.RAbstractContainer; +import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.nodes.RBaseNode; public abstract class NonNANode extends CastNode { @@ -158,8 +165,63 @@ public abstract class NonNANode extends CastNode { } @Specialization(guards = "!isComplete(x)") - protected Object onIncompleteContainer(RAbstractContainer x) { - return handleNA(x); + protected Object onPossiblyIncompleteContainer(RAbstractIntVector x) { + int len = x.getLength(); + for (int i = 0; i < len; i++) { + if (RRuntime.isNA(x.getDataAt(i))) { + return handleNA(x); + } + } + return x; + } + + @Specialization(guards = "!isComplete(x)") + protected Object onPossiblyIncompleteContainer(RAbstractLogicalVector x) { + int len = x.getLength(); + for (int i = 0; i < len; i++) { + if (RRuntime.isNA(x.getDataAt(i))) { + return handleNA(x); + } + } + return x; + } + + @Specialization(guards = "!isComplete(x)") + protected Object onPossiblyIncompleteContainer(RAbstractDoubleVector x) { + int len = x.getLength(); + for (int i = 0; i < len; i++) { + if (RRuntime.isNA(x.getDataAt(i))) { + return handleNA(x); + } + } + return x; + } + + @Specialization(guards = "!isComplete(x)") + protected Object onPossiblyIncompleteContainer(RAbstractComplexVector x) { + int len = x.getLength(); + for (int i = 0; i < len; i++) { + if (RRuntime.isNA(x.getDataAt(i))) { + return handleNA(x); + } + } + return x; + } + + @Specialization(guards = "!isComplete(x)") + protected Object onPossiblyIncompleteContainer(RAbstractStringVector x) { + int len = x.getLength(); + for (int i = 0; i < len; i++) { + if (RRuntime.isNA(x.getDataAt(i))) { + return handleNA(x); + } + } + return x; + } + + @Specialization(guards = "!isComplete(x)") + protected Object onPossiblyIncompleteContainer(RAbstractRawVector x) { + return x; } }