diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Crossprod.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Crossprod.java index 177768859206b1c9efe5fd9e697239ff55fffd5c..bd48e8980243abb1d75214c3cc5b599653657019 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Crossprod.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Crossprod.java @@ -52,7 +52,7 @@ public abstract class Crossprod extends RBuiltinNode { return matMult.executeObject(op1, op2); } - private Object transpose(Object value) { + private Object transpose(RAbstractVector value) { if (transpose == null) { CompilerDirectives.transferToInterpreterAndInvalidate(); transpose = insert(TransposeNodeGen.create(null)); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java index c1a394c20d9e60501af8df2d9a4d7e1e24272140..e9088eeb636daa0ba882e7913ceef4134f86132b 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java @@ -13,88 +13,86 @@ package com.oracle.truffle.r.nodes.builtin.base; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; -import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.SUBSTITUTE; +import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; +import java.util.function.BiFunction; +import java.util.function.Function; + +import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.api.profiles.LoopConditionProfile; import com.oracle.truffle.r.nodes.attributes.CopyOfRegAttributesNode; import com.oracle.truffle.r.nodes.attributes.CopyOfRegAttributesNodeGen; import com.oracle.truffle.r.nodes.attributes.InitAttributesNode; import com.oracle.truffle.r.nodes.attributes.PutAttributeNode; import com.oracle.truffle.r.nodes.attributes.PutAttributeNodeGen; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.nodes.profile.VectorLengthProfile; +import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RAttributeProfiles; +import com.oracle.truffle.r.runtime.data.RComplex; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RList; -import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractListVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector; import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.nodes.RNode; -@RBuiltin(name = "t.default", kind = SUBSTITUTE, parameterNames = {"x"}, behavior = PURE) -// TODO INTERNAL +@RBuiltin(name = "t.default", kind = INTERNAL, parameterNames = {"x"}, behavior = PURE) public abstract class Transpose extends RBuiltinNode { private final RAttributeProfiles attrProfiles = RAttributeProfiles.create(); private final BranchProfile hasDimNamesProfile = BranchProfile.create(); private final ConditionProfile isMatrixProfile = ConditionProfile.createBinaryProfile(); + private final VectorLengthProfile lengthProfile = VectorLengthProfile.create(); + private final LoopConditionProfile loopProfile = LoopConditionProfile.createCountingProfile(); + @Child private CopyOfRegAttributesNode copyRegAttributes = CopyOfRegAttributesNodeGen.create(); @Child private InitAttributesNode initAttributes = InitAttributesNode.create(); @Child private PutAttributeNode putDimensions = PutAttributeNodeGen.createDim(); @Child private PutAttributeNode putDimNames = PutAttributeNodeGen.createDimNames(); - public abstract Object execute(Object o); - - @Specialization - protected RNull transpose(RNull value) { - return value; - } - - @Specialization - protected int transpose(int value) { - return value; - } - - @Specialization - protected double transpose(double value) { - return value; - } - - @Specialization - protected byte transpose(byte value) { - return value; - } - - @Specialization(guards = "isEmpty2D(vector)") - protected RAbstractVector transpose(RAbstractVector vector) { - int[] dim = vector.getDimensions(); - return vector.copyWithNewDimensions(new int[]{dim[1], dim[0]}); - } + public abstract Object execute(RAbstractVector o); @FunctionalInterface - private interface InnerLoop<T extends RAbstractVector> { - RVector apply(T vector, int firstDim); + private interface WriteArray<T extends RAbstractVector, A> { + void apply(A array, T vector, int i, int j); } - protected <T extends RAbstractVector> RVector transposeInternal(T vector, InnerLoop<T> innerLoop) { + protected <T extends RAbstractVector, A> RVector transposeInternal(T vector, Function<Integer, A> createArray, WriteArray<T, A> writeArray, BiFunction<A, Boolean, RVector> createResult) { + int length = lengthProfile.profile(vector.getLength()); int firstDim; int secondDim; if (isMatrixProfile.profile(vector.isMatrix())) { firstDim = vector.getDimensions()[0]; secondDim = vector.getDimensions()[1]; } else { - firstDim = vector.getLength(); + firstDim = length; secondDim = 1; } - RNode.reportWork(this, vector.getLength()); + RNode.reportWork(this, length); - RVector r = innerLoop.apply(vector, firstDim); + A array = createArray.apply(length); + int j = 0; + loopProfile.profileCounted(length); + for (int i = 0; loopProfile.inject(i < length); i++, j += firstDim) { + if (j > (length - 1)) { + j -= (length - 1); + } + writeArray.apply(array, vector, i, j); + } + RVector r = createResult.apply(array, vector.isComplete()); // copy attributes copyRegAttributes.execute(vector, r); // set new dimensions @@ -113,61 +111,47 @@ public abstract class Transpose extends RBuiltinNode { return r; } - private static RVector innerLoopInt(RAbstractIntVector vector, int firstDim) { - int[] result = new int[vector.getLength()]; - int j = 0; - for (int i = 0; i < result.length; i++, j += firstDim) { - if (j > (result.length - 1)) { - j -= (result.length - 1); - } - result[i] = vector.getDataAt(j); - } - return RDataFactory.createIntVector(result, vector.isComplete()); + @Specialization + protected RVector transpose(RAbstractIntVector x) { + return transposeInternal(x, l -> new int[l], (a, v, i, j) -> a[i] = v.getDataAt(j), RDataFactory::createIntVector); } - private static RVector innerLoopDouble(RAbstractDoubleVector vector, int firstDim) { - double[] result = new double[vector.getLength()]; - int j = 0; - for (int i = 0; i < result.length; i++, j += firstDim) { - if (j > (result.length - 1)) { - j -= (result.length - 1); - } - result[i] = vector.getDataAt(j); - } - return RDataFactory.createDoubleVector(result, vector.isComplete()); + @Specialization + protected RVector transpose(RAbstractLogicalVector x) { + return transposeInternal(x, l -> new byte[l], (a, v, i, j) -> a[i] = v.getDataAt(j), RDataFactory::createLogicalVector); } - private static RVector innerLoopString(RAbstractStringVector vector, int firstDim) { - String[] result = new String[vector.getLength()]; - int j = 0; - for (int i = 0; i < result.length; i++, j += firstDim) { - if (j > (result.length - 1)) { - j -= (result.length - 1); - } - result[i] = vector.getDataAt(j); - } - return RDataFactory.createStringVector(result, vector.isComplete()); + @Specialization + protected RVector transpose(RAbstractDoubleVector x) { + return transposeInternal(x, l -> new double[l], (a, v, i, j) -> a[i] = v.getDataAt(j), RDataFactory::createDoubleVector); } - @Specialization(guards = "!isEmpty2D(vector)") - protected RVector transpose(RAbstractIntVector vector) { - return transposeInternal(vector, Transpose::innerLoopInt); + @Specialization + protected RVector transpose(RAbstractComplexVector x) { + return transposeInternal(x, l -> new double[l * 2], (a, v, i, j) -> { + RComplex d = v.getDataAt(j); + a[i * 2] = d.getRealPart(); + a[i * 2 + 1] = d.getImaginaryPart(); + }, RDataFactory::createComplexVector); } - @Specialization(guards = "!isEmpty2D(vector)") - protected RVector transpose(RAbstractDoubleVector vector) { - return transposeInternal(vector, Transpose::innerLoopDouble); + @Specialization + protected RVector transpose(RAbstractStringVector x) { + return transposeInternal(x, l -> new String[l], (a, v, i, j) -> a[i] = v.getDataAt(j), RDataFactory::createStringVector); } - @Specialization(guards = "!isEmpty2D(vector)") - protected RVector transpose(RAbstractStringVector vector) { - return transposeInternal(vector, Transpose::innerLoopString); + @Specialization + protected RVector transpose(RAbstractListVector x) { + return transposeInternal(x, l -> new Object[l], (a, v, i, j) -> a[i] = v.getDataAt(j), (a, c) -> RDataFactory.createList(a)); } - protected static boolean isEmpty2D(RAbstractVector vector) { - if (!vector.hasDimensions()) { - return false; - } - return vector.getDimensions().length == 2 && vector.getLength() == 0; + @Specialization + protected RVector transpose(RAbstractRawVector x) { + return transposeInternal(x, l -> new byte[l], (a, v, i, j) -> a[i] = v.getRawDataAt(j), (a, c) -> RDataFactory.createRawVector(a)); + } + + @Fallback + protected RVector transpose(@SuppressWarnings("unused") Object x) { + throw RError.error(RError.SHOW_CALLER, Message.ARGUMENT_NOT_MATRIX); } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index 011b955a5e0015d22f52d161aa80e492eaa8f5fb..4bfe99947dd762503e46c95f024839807cf5478f 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -47956,6 +47956,39 @@ integer(0) #{ u <- function() sys.parents() ; f <- function(x) x ; g <- function(y) f(y) ; h <- function(z=u()) g(z) ; h() } [1] 0 1 2 1 +##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTranspose +#t(1) + [,1] +[1,] 1 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTranspose +#t(TRUE) + [,1] +[1,] TRUE + +##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTranspose +#t(as.raw(c(1,2,3,4))) + [,1] [,2] [,3] [,4] +[1,] 01 02 03 04 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTranspose +#t(new.env()) +Error in t.default(new.env()) : argument is not a matrix + +##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTranspose +#v <- as.complex(1:50); dim(v) <- c(5,10); dimnames(v) <- list(as.character(40:44), as.character(10:19)); t(v) + 40 41 42 43 44 +10 1+0i 2+0i 3+0i 4+0i 5+0i +11 6+0i 7+0i 8+0i 9+0i 10+0i +12 11+0i 12+0i 13+0i 14+0i 15+0i +13 16+0i 17+0i 18+0i 19+0i 20+0i +14 21+0i 22+0i 23+0i 24+0i 25+0i +15 26+0i 27+0i 28+0i 29+0i 30+0i +16 31+0i 32+0i 33+0i 34+0i 35+0i +17 36+0i 37+0i 38+0i 39+0i 40+0i +18 41+0i 42+0i 43+0i 44+0i 45+0i +19 46+0i 47+0i 48+0i 49+0i 50+0i + ##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTranspose #{ m <- double() ; dim(m) <- c(0,4) ; t(m) } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_t.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_t.java index 3d851f24c7e6da980f8521ef2ff438e066cf4420..693cde32eb6e252583908b569377a1a8b142a8c2 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_t.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_t.java @@ -40,5 +40,11 @@ public class TestBuiltin_t extends TestBase { assertEval("{ t(t(matrix(1:4, nrow=2))) }"); assertEval("{ x<-matrix(1:2, ncol=2, dimnames=list(\"a\", c(\"b\", \"c\"))); t(x) }"); + + assertEval("t(new.env())"); + assertEval("v <- as.complex(1:50); dim(v) <- c(5,10); dimnames(v) <- list(as.character(40:44), as.character(10:19)); t(v)"); + assertEval("t(1)"); + assertEval("t(TRUE)"); + assertEval("t(as.raw(c(1,2,3,4)))"); } }