From c5a5a6238edd3024b49a83bdd6523284156ff4b8 Mon Sep 17 00:00:00 2001 From: stepan <stepan.sindelar@oracle.com> Date: Wed, 7 Feb 2018 12:05:58 +0100 Subject: [PATCH] Fix cbind to work with arrays with more than 2 dimensions --- .../truffle/r/nodes/builtin/base/Bind.java | 116 +++++++++++------- .../com/oracle/truffle/r/runtime/RError.java | 1 + .../truffle/r/test/ExpectedTestOutput.test | 70 +++++++++++ .../r/test/builtins/TestBuiltin_cbind.java | 12 +- 4 files changed, 157 insertions(+), 42 deletions(-) diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Bind.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Bind.java index 165ad928ea..c21f97201c 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Bind.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Bind.java @@ -64,6 +64,7 @@ import com.oracle.truffle.r.nodes.unary.PrecedenceNodeGen; import com.oracle.truffle.r.runtime.ArgumentsSignature; import com.oracle.truffle.r.runtime.RArguments; import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.RInternalError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.Utils; @@ -302,9 +303,8 @@ public abstract class Bind extends RBaseNode { if (!GetDimAttributeNode.isArray(dim) || dim.length == 1) { RStringVector names = extractNamesNode.execute(vec); firstDimNames = names == null ? 
RNull.instance : names; - } else { - RInternalError.unimplemented("binding multi-dimensional arrays is not supported"); } + // dimnames are simply ignored for arrays with 3 and more dimensions } if (firstDimNames != RNull.instance) { RAbstractStringVector names = (RAbstractStringVector) firstDimNames; @@ -365,43 +365,74 @@ public abstract class Bind extends RBaseNode { } } } else { - RInternalError.unimplemented("binding multi-dimensional arrays is not supported"); - return 0; + for (int i = 0; i < resDim; i++) { + dimNamesArray[ind++] = RRuntime.NAMES_ATTR_EMPTY_VALUE; + } + return -ind; } } } /** * @param vectors vectors to be combined - * @param res result dims - * @param bindDims columns dim (cbind) or rows dim (rbind) - * @return whether number of rows (cbind) or columns (rbind) in vectors is the same + * @param res (output) dimensions of the resulting matrix. + * @param bindDims (output, size == vectors.length) for each vector gives how many columns + * (cbind) it will fill in the resulting matrix, i.e. 1 for vectors and columns count + * for matrices. + * @return whether some vectors will have to be recycled, i.e. iterated more than once when + * filling the result */ protected boolean getResultDimensions(RAbstractVector[] vectors, int[] res, int[] bindDims) { + /* + * From the documentation: if there are several matrix arguments, they must all have the + * same number of columns (or rows) and this will be the number of columns (or rows) of the + * result. If all the arguments are vectors, the number of columns (rows) in the result is + * equal to the length of the longest vector. Values in shorter arguments are recycled to + * achieve this length (with a warning if they are recycled only fractionally) + * + * When the arguments consist of a mix of matrices and vectors the number of columns (rows) + * of the result is determined by the number of columns (rows) of the matrix arguments. 
Any + * vectors have their values recycled or subsetted to achieve this length. + */ + assert vectors.length > 0; + // NOTE: naming is chosen for cbind version, but it applies to rbind too + int rowsCountMatrix = -1; // the number of rows of the first matrix argument if any + int rowsCountVector = 0; // the max length of simple vectors + int minRowsCountVector = 0; // the min length of simple vectors, to determine return value + int[] rowsCountsVectors = new int[vectors.length]; // rows count of each vector/matrix + int columnsCount = 0; // the total number of columns (cbind) in the resulting matrix int srcDim1Ind = type == BindType.cbind ? 0 : 1; int srcDim2Ind = type == BindType.cbind ? 1 : 0; - assert vectors.length > 0; - RAbstractVector v = vectorProfile.profile(vectors[0]); - int[] dim = getDimensions(v, getVectorDimensions(v)); - assert dim.length == 2; - bindDims[0] = dim[srcDim2Ind]; - res[srcDim1Ind] = dim[srcDim1Ind]; - res[srcDim2Ind] = dim[srcDim2Ind]; - boolean notEqualDims = false; - for (int i = 1; i < vectors.length; i++) { - RAbstractVector v2 = vectorProfile.profile(vectors[i]); - int[] dims = getDimensions(v2, getVectorDimensions(v2)); - assert dims.length == 2; + for (int i = 0; i < vectors.length; i++) { + RAbstractVector v = vectorProfile.profile(vectors[i]); + int[] dims = getVectorDimensions(v); + if (dims == null || dims.length != 2) { + int vectorLen = v.getLength(); + rowsCountsVectors[i] = vectorLen; + rowsCountVector = Math.max(rowsCountVector, vectorLen); + minRowsCountVector = Math.min(minRowsCountVector, vectorLen); + columnsCount++; + bindDims[i] = 1; + continue; + } + columnsCount += dims[srcDim2Ind]; bindDims[i] = dims[srcDim2Ind]; - if (dims[srcDim1Ind] != res[srcDim1Ind]) { - notEqualDims = true; - if (dims[srcDim1Ind] > res[srcDim1Ind]) { - res[srcDim1Ind] = dims[srcDim1Ind]; - } + if (rowsCountMatrix == -1) { + rowsCountMatrix = dims[srcDim1Ind]; + + } else if (rowsCountMatrix != dims[srcDim1Ind]) { + error(type == 
BindType.cbind ? Message.ROWS_MUST_MATCH : Message.COLS_MUST_MATCH, i + 1); + } + } + res[srcDim2Ind] = columnsCount; + int resultRowsCount = res[srcDim1Ind] = rowsCountMatrix != -1 ? rowsCountMatrix : rowsCountVector; + for (int i = 0; i < vectors.length; i++) { + if (rowsCountsVectors[i] != 0 && rowsCountsVectors[i] > resultRowsCount) { + warning(type == BindType.cbind ? Message.ROWS_NOT_MULTIPLE : Message.COLUMNS_NOT_MULTIPLE, i + 1); + break; // reported only for first occurrence } - res[srcDim2Ind] += dims[srcDim2Ind]; } - return notEqualDims; + return minRowsCountVector < resultRowsCount; } protected int[] getDimensions(RAbstractVector vector, int[] dimensions) { @@ -504,21 +535,24 @@ public abstract class Bind extends RBaseNode { } // compute result vector values - int vecLength = vec.getLength(); - for (int j = 0; j < vecLength; j++) { - result.transferElementSameType(ind++, vec, j); - } - if (rowsAndColumnsNotEqual) { - everSeenNotEqualRows.enter(); - if (vecLength < resultDimensions[0]) { - // re-use vector elements - int k = 0; - for (int j = 0; j < resultDimensions[0] - vecLength; j++, k = Utils.incMod(k, vecLength)) { - result.transferElementSameType(ind++, vectors[i], k); - } - - if (k != 0) { - RError.warning(this, RError.Message.ROWS_NOT_MULTIPLE, i + 1); + int[] dims = getDimensions(vec, getVectorDimensions(vec)); + assert dims.length == 2; + for (int col = 0; col < dims[1]; col++) { + int rowsCount = Math.min(dims[0], resultDimensions[0]); + for (int row = 0; row < rowsCount; row++) { + result.transferElementSameType(ind++, vec, dims[0] * col + row); + } + if (rowsAndColumnsNotEqual) { + everSeenNotEqualRows.enter(); + if (rowsCount < resultDimensions[0]) { + // re-use vector elements + int k = 0; + for (int j = 0; j < resultDimensions[0] - dims[0]; j++, k = Utils.incMod(k, dims[0])) { + result.transferElementSameType(ind++, vectors[i], dims[0] * col + k); + } + if (k != 0) { + RError.warning(this, RError.Message.ROWS_NOT_MULTIPLE, i + 1); + } } } 
} diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java index 529b416957..2226824b9f 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java @@ -589,6 +589,7 @@ public final class RError extends RuntimeException implements TruffleException { IS_NULL("'%s' is NULL"), MUST_BE_SCALAR("'%s' must be of length 1"), ROWS_MUST_MATCH("number of rows of matrices must match (see arg %d)"), + COLS_MUST_MATCH("number of columns of matrices must match (see arg %d)"), ROWS_NOT_MULTIPLE("number of rows of result is not a multiple of vector length (arg %d)"), ARG_ONE_OF("'%s' should be one of %s"), MUST_BE_SQUARE_MATRIX("'%s' must be a square matrix"), diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index 27a1af0f79..a730415c2a 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -12661,6 +12661,26 @@ Loading required package: splines head foo2 ##com.oracle.truffle.r.test.builtins.TestBuiltin_cbind.testCbind# +#cbind(1, 1:4, matrix(1:8, nrow=2)) + [,1] [,2] [,3] [,4] [,5] [,6] +[1,] 1 1 1 3 5 7 +[2,] 1 2 2 4 6 8 +Warning message: +In cbind(1, 1:4, matrix(1:8, nrow = 2)) : + number of rows of result is not a multiple of vector length (arg 2) + +##com.oracle.truffle.r.test.builtins.TestBuiltin_cbind.testCbind#Output.IgnoreWarningContext# +#cbind(1:2, 1:3, 1:4) + [,1] [,2] [,3] +[1,] 1 1 1 +[2,] 2 2 2 +[3,] 1 3 3 +[4,] 2 1 4 +Warning message: +In cbind(1:2, 1:3, 1:4) : + number of rows of result is not a multiple of vector length (arg 2) + +##com.oracle.truffle.r.test.builtins.TestBuiltin_cbind.testCbind# #cbind(55, 
character(0)) [,1] [1,] "55" @@ -12670,6 +12690,15 @@ head a [1,] "55" +##com.oracle.truffle.r.test.builtins.TestBuiltin_cbind.testCbind# +#cbind(array(1:8,c(2,4),list(c('x','y'), c('a','b', 'a2', 'b2'))), 1:8) + a b a2 b2 +x 1 3 5 7 1 +y 2 4 6 8 2 +Warning message: +In cbind(array(1:8, c(2, 4), list(c("x", "y"), c("a", "b", "a2", : + number of rows of result is not a multiple of vector length (arg 2) + ##com.oracle.truffle.r.test.builtins.TestBuiltin_cbind.testCbind# #cbind(character(0)) [,1] @@ -12679,6 +12708,11 @@ head [,1] [1,] "f" +##com.oracle.truffle.r.test.builtins.TestBuiltin_cbind.testCbind# +#cbind(matrix(1:4,nrow=2), matrix(1:8,nrow=4)) +Error in cbind(matrix(1:4, nrow = 2), matrix(1:8, nrow = 4)) : + number of rows of matrices must match (see arg 2) + ##com.oracle.truffle.r.test.builtins.TestBuiltin_cbind.testCbind# #v <- 1:3; attr(v, 'a') <- 'a'; attr(v, 'a1') <- 'a1'; cbind(v); cbind(v, v) v @@ -12914,6 +12948,42 @@ c 2 <NA> 2 <NA> 2 +##com.oracle.truffle.r.test.builtins.TestBuiltin_cbind.testDimnames# +#cbind(array(1:8,c(2,2),list(c('a','b'), c('d','e'))), array(1:4,c(2,2),list(c('f','g'), c('h','i')))) + d e h i +a 1 3 1 3 +b 2 4 2 4 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_cbind.testDimnames#Output.IgnoreWarningContext# +#cbind(array(1:8,c(2,2),list(c('a','b'), c('d','e'))), array(1:8,c(2,2,2),list(c('a1','b1'), c('d1','e1'), c('f1','g1')))) + d e +a 1 3 1 +b 2 4 2 +Warning message: +In cbind(array(1:8, c(2, 2), list(c("a", "b"), c("d", "e"))), array(1:8, : + number of rows of result is not a multiple of vector length (arg 2) + +##com.oracle.truffle.r.test.builtins.TestBuiltin_cbind.testDimnames# +#cbind(array(1:8,c(2,2,2),list(c('a','b'), c('d','e'), c('f','g'))), array(1:8,c(2,2,2),list(c('a1','b1'), c('d1','e1'), c('f1','g1')))) + [,1] [,2] +[1,] 1 1 +[2,] 2 2 +[3,] 3 3 +[4,] 4 4 +[5,] 5 5 +[6,] 6 6 +[7,] 7 7 +[8,] 8 8 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_cbind.testDimnames# +#cbind(array(1:8,c(2,4),list(c('x','y'), 
c('a','b', 'a2', 'b2'))), array(1:8,c(2,2,2),list(c('a1','b1'), c('d1','e1'), c('f1','g1')))) + a b a2 b2 +x 1 3 5 7 1 +y 2 4 6 8 2 +Warning message: +In cbind(array(1:8, c(2, 4), list(c("x", "y"), c("a", "b", "a2", : + number of rows of result is not a multiple of vector length (arg 2) + ##com.oracle.truffle.r.test.builtins.TestBuiltin_cbind.testDimnames# #{ attributes(cbind(1L)) } $dim diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_cbind.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_cbind.java index 4f598802c3..81f4608d39 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_cbind.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_cbind.java @@ -4,7 +4,7 @@ * http://www.gnu.org/licenses/gpl-2.0.html * * Copyright (c) 2012-2014, Purdue University - * Copyright (c) 2013, 2017, Oracle and/or its affiliates + * Copyright (c) 2013, 2018, Oracle and/or its affiliates * * All rights reserved. 
*/ @@ -106,6 +106,11 @@ public class TestBuiltin_cbind extends TestBase { assertEval("v <- 1; attr(v, 'a') <- 'a'; attr(v, 'a1') <- 'a1'; cbind(v); cbind(v, v)"); assertEval("v <- 1:3; attr(v, 'a') <- 'a'; attr(v, 'a1') <- 'a1'; cbind(v); cbind(v, v)"); assertEval("v <- 1:3; v1<-1:3; attr(v, 'a') <- 'a'; attr(v1, 'a1') <- 'a1'; cbind(v, v1)"); + + assertEval("cbind(array(1:8,c(2,4),list(c('x','y'), c('a','b', 'a2', 'b2'))), 1:8)"); + assertEval("cbind(1, 1:4, matrix(1:8, nrow=2))"); + assertEval("cbind(matrix(1:4,nrow=2), matrix(1:8,nrow=4))"); + assertEval(Output.IgnoreWarningContext, "cbind(1:2, 1:3, 1:4)"); } @Test @@ -166,6 +171,11 @@ public class TestBuiltin_cbind extends TestBase { assertEval("{ attributes(cbind(integer(0), integer(0))) }"); assertEval("{ attributes(cbind(c(1), integer(0))) }"); assertEval("{ attributes(cbind(structure(1:4, dim=c(2,2), dimnames=list(y=c('y1', 'y2'), x=c('x1', 'x2'))), integer(0))) }"); + + assertEval("cbind(array(1:8,c(2,2,2),list(c('a','b'), c('d','e'), c('f','g'))), array(1:8,c(2,2,2),list(c('a1','b1'), c('d1','e1'), c('f1','g1'))))"); + assertEval("cbind(array(1:8,c(2,2),list(c('a','b'), c('d','e'))), array(1:4,c(2,2),list(c('f','g'), c('h','i'))))"); + assertEval(Output.IgnoreWarningContext, "cbind(array(1:8,c(2,2),list(c('a','b'), c('d','e'))), array(1:8,c(2,2,2),list(c('a1','b1'), c('d1','e1'), c('f1','g1'))))"); + assertEval("cbind(array(1:8,c(2,4),list(c('x','y'), c('a','b', 'a2', 'b2'))), array(1:8,c(2,2,2),list(c('a1','b1'), c('d1','e1'), c('f1','g1'))))"); } @Test -- GitLab