diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Split.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Split.java index 8a5446c3d6af500002450e0b5686860bd2c4f872..f309975878a3af3f16c66c8e2a099601b24bb498 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Split.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Split.java @@ -35,9 +35,11 @@ import com.oracle.truffle.r.runtime.Utils; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RList; +import com.oracle.truffle.r.runtime.data.RRawVector; import com.oracle.truffle.r.runtime.data.RStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractListVector; import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; @@ -47,6 +49,9 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; * * TODO Can we find a way to efficiently write the specializations as generics? The code is * identical except for the argument type. + * + * TODO: GNU R preserves the corresponding values of names attribute. There are (ignored) tests for + * this in TestBuiltin_split. */ @RBuiltin(name = "split", kind = INTERNAL, parameterNames = {"x", "f"}, behavior = PURE) public abstract class Split extends RBuiltinNode.Arg2 { @@ -62,9 +67,37 @@ public abstract class Split extends RBuiltinNode.Arg2 { Casts.noCasts(Split.class); } - public static class SplitTemplate { - @SuppressWarnings("unused") private int[] collectResultsSize; - @SuppressWarnings("unused") private int nLevels; + @Specialization + protected RList split(RAbstractListVector x, RAbstractIntVector f) { + int[] factor = f.materialize().getDataWithoutCopying(); + RStringVector names = getLevelNode.execute(f); + final int nLevels = getNLevels(names); + + // initialise result arrays + Object[][] collectResults = new Object[nLevels][]; + int[] collectResultSize = new int[nLevels]; + for (int i = 0; i < collectResults.length; i++) { + collectResults[i] = new Object[INITIAL_SIZE]; + } + + // perform split + for (int i = 0, fi = 0; i < x.getLength(); ++i, fi = Utils.incMod(fi, factor.length)) { + int resultIndex = factor[fi] - 1; // a factor is a 1-based int vector + Object[] collect = collectResults[resultIndex]; + if (collect.length == collectResultSize[resultIndex]) { + collectResults[resultIndex] = Arrays.copyOf(collect, collect.length * SCALE_FACTOR); + collect = collectResults[resultIndex]; + } + collect[collectResultSize[resultIndex]++] = x.getDataAt(i); + } + + // assemble result vectors and level names + Object[] results = new Object[nLevels]; + for (int i = 0; i < nLevels; i++) { + results[i] = RDataFactory.createList(Arrays.copyOfRange(collectResults[i], 0, collectResultSize[i])); + } + + return RDataFactory.createList(results, names); } @Specialization @@ -199,6 +232,39 @@ public abstract class Split extends RBuiltinNode.Arg2 { return RDataFactory.createList(results, names); } + @Specialization + protected RList split(RRawVector x, RAbstractIntVector f) { + int[] factor = f.materialize().getDataWithoutCopying(); + RStringVector names = getLevelNode.execute(f); + final int nLevels = getNLevels(names); + + // initialise result arrays + byte[][] collectResults = new byte[nLevels][]; + int[] collectResultSize = new int[nLevels]; + for (int i = 0; i < collectResults.length; i++) { + collectResults[i] = new byte[INITIAL_SIZE]; + } + + // perform split + for (int i = 0, fi = 0; i < x.getLength(); ++i, fi = Utils.incMod(fi, factor.length)) { + int resultIndex = factor[fi] - 1; // a factor is a 1-based int vector + byte[] collect = collectResults[resultIndex]; + if (collect.length == collectResultSize[resultIndex]) { + collectResults[resultIndex] = Arrays.copyOf(collect, collect.length * SCALE_FACTOR); + collect = collectResults[resultIndex]; + } + collect[collectResultSize[resultIndex]++] = x.getDataAt(i).getValue(); + } + + // assemble result vectors and level names + Object[] results = new Object[nLevels]; + for (int i = 0; i < nLevels; i++) { + results[i] = RDataFactory.createRawVector(Arrays.copyOfRange(collectResults[i], 0, collectResultSize[i])); + } + + return RDataFactory.createList(results, names); + } + private static int getNLevels(RStringVector levels) { return levels != null ? levels.getLength() : 0; } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index c03ce21ab0cb9070e7d00e468c82274b27f1c870..982b647f335990c04417558dc2d95376e77ca2bb 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -64273,6 +64273,14 @@ $`2` [1] 2 4 6 8 10 +##com.oracle.truffle.r.test.builtins.TestBuiltin_split.testSplit# +#{ split(as.raw(1:10), as.factor(c('a', 'b', 'a')); } +Error: unexpected ';' in "{ split(as.raw(1:10), as.factor(c('a', 'b', 'a'));" + +##com.oracle.truffle.r.test.builtins.TestBuiltin_split.testSplit# +#{ split(list(1, 2L, 'x', T), as.factor(c('a', 'b', 'a')); } +Error: unexpected ';' in "{ split(list(1, 2L, 'x', T), as.factor(c('a', 'b', 'a'));" + ##com.oracle.truffle.r.test.builtins.TestBuiltin_split.testSplit# #{ x <- factor(c("a", "b", "a")); attr(x, "levels")<-c(7L, 42L) ; split(1:3, x) } $`7` @@ -64282,6 +64290,24 @@ $`42` [1] 2 +##com.oracle.truffle.r.test.builtins.TestBuiltin_split.testSplitWithNames#Ignored.Unimplemented# +#{ split(list(q=1, w=2L, e='x', r=T), as.factor(c('a', 'b', 'a')); } +Error: unexpected ';' in "{ split(list(q=1, w=2L, e='x', r=T), as.factor(c('a', 'b', 'a'));" + +##com.oracle.truffle.r.test.builtins.TestBuiltin_split.testSplitWithNames#Ignored.Unimplemented# +#{ tmp <- c(1,2,3); names(tmp) <- c('x','y','z'); split(tmp, as.factor(c('a','b'))); } +$a +x z +1 3 + +$b +y +2 + +Warning message: +In split.default(tmp, as.factor(c("a", "b"))) : + data length is not a multiple of split variable + ##com.oracle.truffle.r.test.builtins.TestBuiltin_split.testsplit1# #argv <- list(1:6, structure(1:2, .Label = c('1', '2'), class = 'factor')); .Internal(split(argv[[1]], argv[[2]])) $`1` diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_split.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_split.java index 234bd45d0266926805884872852794e57ed378de..d551d2225e677fb93c52e60517b814c17b2d78ff 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_split.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_split.java @@ -116,5 +116,13 @@ public class TestBuiltin_split extends TestBase { assertEval("{ fu <- c(\"a\",\"b\") ; split(1:8,fu) }"); assertEval("{ g <- factor(round(c(0.4,1.3,0.6,1.8,2.5,4.1,2.2,1.0))) ; x <- c(0.1,3.2,1,0.6,1.9,3.3,1.6,1.7) + sqrt(as.numeric(g)) ; xg <- split(x, g) ; xg }"); assertEval("{ x <- factor(c(\"a\", \"b\", \"a\")); attr(x, \"levels\")<-c(7L, 42L) ; split(1:3, x) }"); + assertEval("{ split(list(1, 2L, 'x', T), as.factor(c('a', 'b', 'a')); }"); + assertEval("{ split(as.raw(1:10), as.factor(c('a', 'b', 'a')); }"); + } + + @Test + public void testSplitWithNames() { + assertEval(Ignored.Unimplemented, "{ split(list(q=1, w=2L, e='x', r=T), as.factor(c('a', 'b', 'a')); }"); + assertEval(Ignored.Unimplemented, "{ tmp <- c(1,2,3); names(tmp) <- c('x','y','z'); split(tmp, as.factor(c('a','b'))); }"); } }