diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Combine.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Combine.java index 2673ed6c34d8d4275321eaf956bf17b1710a31ba..73afa80f8a426445181447680e2e8bc93c66e1e3 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Combine.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Combine.java @@ -22,6 +22,7 @@ */ package com.oracle.truffle.r.nodes.builtin.base; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean; import static com.oracle.truffle.r.nodes.unary.PrecedenceNode.COMPLEX_PRECEDENCE; import static com.oracle.truffle.r.nodes.unary.PrecedenceNode.DOUBLE_PRECEDENCE; import static com.oracle.truffle.r.nodes.unary.PrecedenceNode.EXPRESSION_PRECEDENCE; @@ -46,6 +47,7 @@ import com.oracle.truffle.api.nodes.ExplodeLoop; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.api.profiles.ValueProfile; +import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.builtin.base.CombineNodeGen.CombineInputCastNodeGen; import com.oracle.truffle.r.nodes.unary.CastComplexNodeGen; @@ -76,7 +78,7 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.nodes.RNode; import com.oracle.truffle.r.runtime.ops.na.NACheck; -@RBuiltin(name = "c", kind = PRIMITIVE, parameterNames = {"..."}, dispatch = INTERNAL_GENERIC, behavior = PURE) +@RBuiltin(name = "c", kind = PRIMITIVE, parameterNames = {"...", "recursive"}, dispatch = INTERNAL_GENERIC, behavior = PURE) public abstract class Combine extends RBuiltinNode { public static Combine create() { @@ -101,19 +103,24 @@ public abstract class Combine extends RBuiltinNode { private final ConditionProfile hasNewNamesProfile = ConditionProfile.createBinaryProfile(); @CompilationFinal private final ValueProfile[] argProfiles = new ValueProfile[MAX_PROFILES]; - public abstract Object executeCombine(Object value); + @Override + protected void createCasts(CastBuilder casts) { + casts.arg("recursive").asLogicalVector().findFirst(RRuntime.LOGICAL_NA).map(toBoolean()); + } + + public abstract Object executeCombine(Object value, Object recursive); protected boolean isSimpleArguments(RArgsValuesAndNames args) { return !signatureHasNames(args.getSignature()) && args.getLength() == 1 && !(args.getArgument(0) instanceof RAbstractVector); } - @Specialization(guards = "isSimpleArguments(args)") - protected Object combineSimple(RArgsValuesAndNames args) { + @Specialization(guards = {"isSimpleArguments(args)", "!recursive"}) + protected Object combineSimple(RArgsValuesAndNames args, @SuppressWarnings("unused") boolean recursive) { return args.getArgument(0); } - @Specialization(contains = "combineSimple", limit = "1", guards = {"args.getSignature() == cachedSignature", "cachedPrecedence == precedence(args, cachedSignature.getLength())"}) - protected Object combineCached(RArgsValuesAndNames args, // + @Specialization(contains = "combineSimple", limit = "1", guards = {"!recursive", "args.getSignature() == cachedSignature", "cachedPrecedence == precedence(args, cachedSignature.getLength())"}) + protected Object combineCached(RArgsValuesAndNames args, @SuppressWarnings("unused") boolean recursive, // @Cached("args.getSignature()") ArgumentsSignature cachedSignature, // @Cached("precedence( args, cachedSignature.getLength())") int cachedPrecedence, // @Cached("createCast(cachedPrecedence)") CastNode cast, // @@ -144,6 +151,50 @@ public abstract class Combine extends RBuiltinNode { return result; } + @TruffleBoundary + @Specialization(limit = "COMBINE_CACHED_LIMIT", contains = "combineCached", guards = {"!recursive", "cachedPrecedence == precedence(args)"}) + protected Object combine(RArgsValuesAndNames args, @SuppressWarnings("unused") boolean recursive, // + @Cached("precedence(args, args.getLength())") int cachedPrecedence, // + @Cached("createCast(cachedPrecedence)") CastNode cast, // + @Cached("create()") BranchProfile naNameBranch, // + @Cached("create()") NACheck naNameCheck, // + @Cached("createBinaryProfile()") ConditionProfile hasNamesProfile) { + return combineCached(args, false, args.getSignature(), cachedPrecedence, cast, naNameBranch, naNameCheck, hasNamesProfile); + } + + @Specialization(guards = "recursive") + protected Object combineRecursive(RArgsValuesAndNames args, @SuppressWarnings("unused") boolean recursive, + @Cached("create()") Combine recursiveCombine, // + @Cached("createBinaryProfile()") ConditionProfile useNewArgsProfile) { + return combineRecursive(args, recursiveCombine, useNewArgsProfile); + } + + @SuppressWarnings("static-method") + @ExplodeLoop + private Object combineRecursive(RArgsValuesAndNames args, Combine recursiveCombine, ConditionProfile useNewArgsProfile) { + Object[] argsArray = args.getArguments(); + Object[] newArgsArray = new Object[argsArray.length]; + boolean useNewArgs = false; + for (int i = 0; i < argsArray.length; i++) { + Object arg = argsArray[i]; + if (arg instanceof RList) { + Object[] argsFromList = ((RList) arg).getDataWithoutCopying(); + newArgsArray[i] = recursiveCombine.executeCombine(new RArgsValuesAndNames(argsFromList, + ArgumentsSignature.empty(argsFromList.length)), true); + useNewArgs = true; + } else { + newArgsArray[i] = arg; + } + } + + if (useNewArgsProfile.profile(useNewArgs)) { + return recursiveCombine.executeCombine(new RArgsValuesAndNames(newArgsArray, + args.getSignature()), false); + } else { + return recursiveCombine.executeCombine(args, false); + } + } + @ExplodeLoop private int prepareElements(Object[] args, CastNode cast, int precedence, Object[] elements) { int size = 0; @@ -278,21 +329,9 @@ public abstract class Combine extends RBuiltinNode { return signature != null && signature.getNonNullCount() > 0; } - @TruffleBoundary - @Specialization(limit = "COMBINE_CACHED_LIMIT", contains = "combineCached", guards = "cachedPrecedence == precedence(args)") - protected Object combine(RArgsValuesAndNames args, // - @Cached("precedence(args, args.getLength())") int cachedPrecedence, // - @Cached("createCast(cachedPrecedence)") CastNode cast, // - @Cached("create()") BranchProfile naNameBranch, // - @Cached("create()") NACheck naNameCheck, // - @Cached("createBinaryProfile()") ConditionProfile hasNamesProfile) { - return combineCached(args, args.getSignature(), cachedPrecedence, cast, naNameBranch, naNameCheck, hasNamesProfile); - } - @Specialization(guards = "!isArguments(args)") - protected Object nonArguments(Object args, - @Cached("create()") Combine combine) { - return combine.executeCombine(new RArgsValuesAndNames(new Object[]{args}, EMPTY_SIGNATURE)); + protected Object nonArguments(Object args, boolean recursive, @Cached("create()") Combine combine) { + return combine.executeCombine(new RArgsValuesAndNames(new Object[]{args}, EMPTY_SIGNATURE), recursive); } private Object readAndCast(CastNode castNode, Object arg, int precedence) { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Max.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Max.java index dac7f3b6519043254b51eaf7301031457480b257..77bf214173a7afaccab6774615aa496602e993ac 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Max.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Max.java @@ -66,6 +66,6 @@ public abstract class Max extends RBuiltinNode { @Specialization(contains = "maxLengthOne") protected Object max(RArgsValuesAndNames args, boolean naRm, // @Cached("create()") Combine combine) { - return reduce.executeReduce(combine.executeCombine(args), naRm, false); + return reduce.executeReduce(combine.executeCombine(args, false), naRm, false); } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Min.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Min.java index b29a667281d58fef36227aa7ca87672db093c4af..3da49f67fdc8e361b2150ff9c6997523c2e23d20 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Min.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Min.java @@ -66,6 +66,6 @@ public abstract class Min extends RBuiltinNode { @Specialization(contains = "minLengthOne") protected Object min(RArgsValuesAndNames args, boolean naRm, // @Cached("create()") Combine combine) { - return reduce.executeReduce(combine.executeCombine(args), naRm, false); + return reduce.executeReduce(combine.executeCombine(args, false), naRm, false); } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Range.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Range.java index 8176fec9cefbc7e714de3fe4a58a821c20841f55..5c2486f07d81052ab6848c023ba071e89b6cfe56 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Range.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Range.java @@ -82,7 +82,7 @@ public abstract class Range extends RBuiltinNode { @Specialization(contains = "rangeLengthOne") protected RVector<?> range(RArgsValuesAndNames args, boolean naRm, boolean finite, // @Cached("create()") Combine combine) { - Object combined = combine.executeCombine(args); + Object combined = combine.executeCombine(args, false); Object min = minReduce.executeReduce(combined, naRm, finite); Object max = maxReduce.executeReduce(combined, naRm, finite); return createResult(min, max); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sum.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sum.java index b356056c0a2f2ea07c6f8653f6728bbcce4b6116..bee3a6cf603cf725dcafa9ca6ded51f5fd1ac85c 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sum.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sum.java @@ -110,6 +110,6 @@ public abstract class Sum extends RBuiltinNode { @Specialization(contains = {"sumLengthOneRDoubleVector", "sumLengthOne"}) protected Object sum(RArgsValuesAndNames args, boolean naRm, // @Cached("create()") Combine combine) { - return reduce.executeReduce(combine.executeCombine(args), naRm, false); + return reduce.executeReduce(combine.executeCombine(args, false), naRm, false); } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index a65d85bf1233c9fa53e3a2f8ac4fcb72d77df7a1..4a8f8ffe6e46a0ad315fa4ce911cf4a0e449842d 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -10364,6 +10364,18 @@ expression(1, 2) #{ c(1i,0/0) } [1] 0+1i NaN+0i +##com.oracle.truffle.r.test.builtins.TestBuiltin_c.testRecursive +#argv <- list(c(1,2),c(3,4),c(5,6), recursive=TRUE));c(argv[[1]]); +Error: unexpected ')' in "argv <- list(c(1,2),c(3,4),c(5,6), recursive=TRUE))" + +##com.oracle.truffle.r.test.builtins.TestBuiltin_c.testRecursive +#argv <- list(c(list(c(1,2),c(3,4)),c(5,6), recursive=TRUE));c(argv[[1]]); +[1] 1 2 3 4 5 6 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_c.testRecursive +#argv <- list(list(), recursive=TRUE));c(argv[[1]]); +Error: unexpected ')' in "argv <- list(list(), recursive=TRUE))" + ##com.oracle.truffle.r.test.builtins.TestBuiltin_c.testc1 #argv <- list(character(0), 'myLib/myTst');c(argv[[1]],argv[[2]]); [1] "myLib/myTst" diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_c.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_c.java index 03f127743633fcfd51b15cf67971db56ae270f5c..223602182781ec7fd20a324d37139797968398f4 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_c.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_c.java @@ -533,6 +533,13 @@ public class TestBuiltin_c extends TestBase { assertEval("{ x<-c(a=42); y<-c(b=7); z<-c(x,y); w<-names(z); w[[1]]<-\"c\"; z }"); } + @Test + public void testRecursive() { + assertEval("argv <- list(c(list(c(1,2),c(3,4)),c(5,6), recursive=TRUE));c(argv[[1]]);"); + assertEval("argv <- list(c(1,2),c(3,4),c(5,6), recursive=TRUE));c(argv[[1]]);"); + assertEval("argv <- list(list(), recursive=TRUE));c(argv[[1]]);"); + } + @Test public void testCombineBroken() { assertEval(Ignored.Unknown, "{ c(1i,0/0) }"); // yes, this is done by GNU-R, note