diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SeqFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SeqFunctions.java index 39a1d3193eaa36c94d4ff413281aa432c2d133d8..bb4e2eba824e3aee8169d3ef753bd444fdcc8633 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SeqFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SeqFunctions.java @@ -14,6 +14,7 @@ package com.oracle.truffle.r.nodes.builtin.base; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gte; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.size; import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC; +import static com.oracle.truffle.r.runtime.RError.NO_CALLER; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; @@ -22,10 +23,13 @@ import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.ImportStatic; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.TypeSystemReference; +import com.oracle.truffle.api.frame.FrameSlot; import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.r.nodes.access.FrameSlotNode; +import com.oracle.truffle.r.nodes.access.variables.LocalReadVariableNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetClassAttributeNode; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; @@ -36,11 +40,17 @@ import com.oracle.truffle.r.nodes.builtin.base.SeqFunctionsFactory.IsNumericNode import com.oracle.truffle.r.nodes.builtin.base.SeqFunctionsFactory.SeqIntNodeGen; import com.oracle.truffle.r.nodes.builtin.base.SeqFunctionsFactory.SeqIntNodeGen.IsIntegralNumericNodeGen; import com.oracle.truffle.r.nodes.control.RLengthNode; -import com.oracle.truffle.r.nodes.control.RLengthNodeGen; import com.oracle.truffle.r.nodes.ffi.AsRealNode; import com.oracle.truffle.r.nodes.ffi.AsRealNodeGen; import com.oracle.truffle.r.nodes.function.CallMatcherNode.CallMatcherGenericNode; +import com.oracle.truffle.r.nodes.function.ClassHierarchyNode; +import com.oracle.truffle.r.nodes.function.RCallBaseNode; +import com.oracle.truffle.r.nodes.function.RCallNode; +import com.oracle.truffle.r.nodes.unary.CastIntegerNode; +import com.oracle.truffle.r.nodes.unary.FindFirstNode; +import com.oracle.truffle.r.runtime.ArgumentsSignature; import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.RInternalError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RBuiltin; @@ -55,10 +65,12 @@ import com.oracle.truffle.r.runtime.data.RIntVector; import com.oracle.truffle.r.runtime.data.RMissing; import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RSequence; +import com.oracle.truffle.r.runtime.data.RStringVector; import com.oracle.truffle.r.runtime.data.RTypesFlatLayout; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.env.REnvironment; import com.oracle.truffle.r.runtime.nodes.RFastPathNode; /** @@ -81,7 +93,7 @@ import com.oracle.truffle.r.runtime.nodes.RFastPathNode; * {@link SeqInt} node for the fast paths. * */ -public class SeqFunctions { +public final class SeqFunctions { public abstract static class FastPathAdapter extends RFastPathNode { public static IsMissingOrNumericNode createIsMissingOrNumericNode() { @@ -246,7 +258,7 @@ public class SeqFunctions { public static Object[] reorderedArguments(RArgsValuesAndNames argsIn, RFunction seqIntFunction) { RArgsValuesAndNames args = argsIn; if (args.getSignature().getNonNullCount() != 0) { - return CallMatcherGenericNode.reorderArguments(args.getArguments(), seqIntFunction, args.getSignature(), RError.NO_CALLER).getArguments(); + return CallMatcherGenericNode.reorderArguments(args.getArguments(), seqIntFunction, args.getSignature(), NO_CALLER).getArguments(); } else { int len = argsIn.getLength(); Object[] xArgs = new Object[5]; @@ -321,13 +333,62 @@ public class SeqFunctions { @TypeSystemReference(RTypesFlatLayout.class) @RBuiltin(name = "seq_along", kind = PRIMITIVE, parameterNames = {"along.with"}, behavior = PURE) public abstract static class SeqAlong extends RBuiltinNode { + @Child private ClassHierarchyNode classHierarchyNode = ClassHierarchyNode.create(); - @Child private RLengthNode length = RLengthNodeGen.create(); - - @Specialization - protected RIntSequence seq(VirtualFrame frame, Object value) { + @Specialization(guards = "!hasClass(value)") + protected RIntSequence seq(VirtualFrame frame, Object value, + @Cached("create()") RLengthNode length) { return RDataFactory.createIntSequence(1, 1, length.executeInteger(frame, value)); } + + @Specialization(guards = "hasClass(value)") + protected RIntSequence seq(VirtualFrame frame, Object value, + @Cached("create()") LengthDispatcher dispatcher) { + return RDataFactory.createIntSequence(1, 1, dispatcher.execute(frame, value)); + } + + boolean hasClass(Object obj) { + final RStringVector classVec = classHierarchyNode.execute(obj); + return classVec != null && classVec.getLength() != 0; + } + + /** + * Invokes the 'length' function, which may dispatch to some other function than default + * length depending on the class of the argument. + */ + static final class LengthDispatcher extends Node { + private final Object argsIdentifier = new Object(); + private final BranchProfile errorProfile = BranchProfile.create(); + @Child private RCallBaseNode call = RCallNode.createExplicitCall(argsIdentifier); + @Child private FrameSlotNode argumentsSlot = FrameSlotNode.createTemp(argsIdentifier, true); + @Child private LocalReadVariableNode readLength = LocalReadVariableNode.create("length", true); + @Child private CastIntegerNode castInteger = CastIntegerNode.createNonPreserving(); + @Child private FindFirstNode findFirst = FindFirstNode.create(Integer.class, NO_CALLER, Message.NEGATIVE_LENGTH_VECTORS_NOT_ALLOWED); + + public static LengthDispatcher create() { + return new LengthDispatcher(); + } + + public int execute(VirtualFrame frame, Object target) { + FrameSlot argsFrameSlot = argumentsSlot.executeFrameSlot(frame); + try { + frame.setObject(argsFrameSlot, new RArgsValuesAndNames(new Object[]{target}, ArgumentsSignature.empty(1))); + Object lengthFunction = readLength.execute(frame, REnvironment.baseEnv().getFrame()); + int result = castResult(call.execute(frame, lengthFunction)); + if (result < 0 || RRuntime.isNA(result)) { + errorProfile.enter(); + throw RError.error(NO_CALLER, Message.NEGATIVE_LENGTH_VECTORS_NOT_ALLOWED); + } + return result; + } finally { + frame.setObject(argsFrameSlot, null); + } + } + + private int castResult(Object result) { + return (Integer) findFirst.execute(castInteger.execute(result)); + } + } } @TypeSystemReference(RTypesFlatLayout.class) diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/RLengthNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/RLengthNode.java index 34bd86da77f53ca5be32f45c1e0f0aa872bf55d4..14bbeb5c13d9859a7e8421387c9dd5ec6a2e7421 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/RLengthNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/control/RLengthNode.java @@ -41,6 +41,11 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractContainer; import com.oracle.truffle.r.runtime.env.REnvironment; import com.oracle.truffle.r.runtime.nodes.RNode; +/** + * Gets length of given container. Does not actually dispatch to the 'length' function, which may be + * overridden for some S3/S4 classes. Check if you need to get actual length, or what the 'length' + * function returns, like in {@code seq_along}. + */ @NodeChild("operand") public abstract class RLengthNode extends RNode { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/FindFirstNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/FindFirstNode.java index e5c9f3cc86ee783b454cfc857687907d81609d74..2aac18c22c465e9b9b6c91c19ce0699d909efbda 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/FindFirstNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/FindFirstNode.java @@ -52,6 +52,10 @@ public abstract class FindFirstNode extends CastNode { this(elementClass, null, null, null, defaultValue); } + public static FindFirstNode create(Class<?> elementClass, RBaseNode callObj, RError.Message message, Object... args) { + return FindFirstNodeGen.create(elementClass, callObj, message, args, null); + } + public Class<?> getElementClass() { return elementClass; } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index f17543e5fe71799d5a31af8e1a7c21922d2e7ca1..512d438c1f484e3c9c335261017e16a1a528a6e3 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -57323,6 +57323,44 @@ integer(0) [41] 26163.27 26367.35 26571.43 26775.51 26979.59 27183.67 27387.76 27591.84 [49] 27795.92 28000.00 +##com.oracle.truffle.r.test.builtins.TestBuiltin_seq_along.testWithNonStandardLength#Ignored.Unimplemented# +#{ assign('length.myclass', function(...) 42, envir=.__S3MethodsTable__.); x <- 1; class(x) <- 'myclass'; seq_along(x); } + [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 +[26] 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_seq_along.testWithNonStandardLength# +#{ length <- function(x) 42; seq_along(c(1,2,3)) } +[1] 1 2 3 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_seq_along.testWithNonStandardLength# +#{ x <- c(1,2,3); class(x) <- 'myclass'; length.myclass <- function(w) '48'; seq_along(x) } + [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 +[26] 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_seq_along.testWithNonStandardLength#Output.IgnoreWarningContext# +#{ x <- c(1,2,3); class(x) <- 'myclass'; length.myclass <- function(w) 'hello world'; seq_along(x) } +Error: negative length vectors are not allowed +In addition: Warning message: +NAs introduced by coercion + +##com.oracle.truffle.r.test.builtins.TestBuiltin_seq_along.testWithNonStandardLength# +#{ x <- c(1,2,3); class(x) <- 'myclass'; length.myclass <- function(w) 42; seq_along(x) } + [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 +[26] 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_seq_along.testWithNonStandardLength# +#{ x <- c(1,2,3); class(x) <- 'myclass'; length.myclass <- function(w) c(100, 200); seq_along(x) } + [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 + [19] 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 + [37] 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 + [55] 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 + [73] 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 + [91] 91 92 93 94 95 96 97 98 99 100 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_seq_along.testWithNonStandardLength# +#{ x <- c(1,2,3); class(x) <- 'myclass'; length.myclass <- function(w) numeric(0); seq_along(x) } +Error: negative length vectors are not allowed + ##com.oracle.truffle.r.test.builtins.TestBuiltin_seq_along.testseq1# #argv <- list(c('y', 'A', 'U', 'V'));do.call('seq_along', argv); [1] 1 2 3 4 diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_seq_along.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_seq_along.java index 05336f48d283b2d79a9cf35e956063adb96586e7..8cdf55201a05931ff691cc70eb98464e25746128 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_seq_along.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_seq_along.java @@ -131,4 +131,17 @@ public class TestBuiltin_seq_along extends TestBase { assertEval("argv <- list(structure(list(num = 1:4, fac = structure(11:14, .Label = c('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o'), class = 'factor'), date = structure(c(15065, 15066, 15067, 15068), class = 'Date'), pv = structure(list(1:3, 4:5, 6:7, 8:10), class = c('package_version', 'numeric_version'))), .Names = c('num', 'fac', 'date', 'pv'), row.names = c(NA, -4L), class = 'data.frame'));" + "do.call('seq_along', argv)"); } + + @Test + public void testWithNonStandardLength() { + assertEval("{ x <- c(1,2,3); class(x) <- 'myclass'; length.myclass <- function(w) 42; seq_along(x) }"); + assertEval("{ x <- c(1,2,3); class(x) <- 'myclass'; length.myclass <- function(w) c(100, 200); seq_along(x) }"); + assertEval("{ x <- c(1,2,3); class(x) <- 'myclass'; length.myclass <- function(w) '48'; seq_along(x) }"); + assertEval("{ x <- c(1,2,3); class(x) <- 'myclass'; length.myclass <- function(w) numeric(0); seq_along(x) }"); + assertEval(Output.IgnoreWarningContext, "{ x <- c(1,2,3); class(x) <- 'myclass'; length.myclass <- function(w) 'hello world'; seq_along(x) }"); + // length defined in global env should not get us confused: + assertEval("{ length <- function(x) 42; seq_along(c(1,2,3)) }"); + // length in __S3MethodsTable__ should work too, N.B.: needs complete S3 dispatch support + assertEval(Ignored.Unimplemented, "{ assign('length.myclass', function(...) 42, envir=.__S3MethodsTable__.); x <- 1; class(x) <- 'myclass'; seq_along(x); }"); + } }