From 4f5fa3a5fe0b7ea790d30d07ece2c8a228fd32f6 Mon Sep 17 00:00:00 2001 From: Florian Angerer <florian.angerer@oracle.com> Date: Tue, 22 Aug 2017 16:23:09 +0200 Subject: [PATCH] Implemented in-place transpose for simple cases. --- .../r/nodes/builtin/base/Transpose.java | 127 +++++++++++++----- .../truffle/r/test/ExpectedTestOutput.test | 32 +++++ .../r/test/builtins/TestBuiltin_t.java | 11 ++ 3 files changed, 140 insertions(+), 30 deletions(-) diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java index abad492f9e..7c650783f2 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java @@ -32,6 +32,7 @@ import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAt import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimNamesAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.nodes.function.opt.ReuseNonSharedNode; import com.oracle.truffle.r.nodes.profile.VectorLengthProfile; import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.builtins.RBuiltin; @@ -55,6 +56,7 @@ public abstract class Transpose extends RBuiltinNode.Arg1 { private final BranchProfile hasDimNamesProfile = BranchProfile.create(); private final ConditionProfile isMatrixProfile = ConditionProfile.createBinaryProfile(); + private final BranchProfile isNonSharedProfile = BranchProfile.create(); private final VectorLengthProfile lengthProfile = VectorLengthProfile.create(); private final LoopConditionProfile loopProfile = LoopConditionProfile.createCountingProfile(); @@ -66,6 +68,7 @@ public abstract class Transpose extends RBuiltinNode.Arg1 { @Child private GetDimNamesAttributeNode getDimNamesNode = GetDimNamesAttributeNode.create(); @Child private GetNamesAttributeNode getAxisNamesNode = GetNamesAttributeNode.create(); @Child private GetDimAttributeNode getDimNode; + @Child private ReuseNonSharedNode reuseNonShared = ReuseNonSharedNode.create(); static { Casts.noCasts(Transpose.class); @@ -78,22 +81,20 @@ public abstract class Transpose extends RBuiltinNode.Arg1 { void apply(A array, T vector, int i, int j); } + @FunctionalInterface + private interface Swap { + /** Swap element at (i, j) with element at (j, i). */ + void swap(int i, int j); + } + protected <T extends RAbstractVector, A> RVector<?> transposeInternal(T vector, Function<Integer, A> createArray, WriteArray<T, A> writeArray, BiFunction<A, Boolean, RVector<?>> createResult) { int length = lengthProfile.profile(vector.getLength()); int firstDim; int secondDim; - if (isMatrixProfile.profile(vector.isMatrix())) { - if (getDimNode == null) { - CompilerDirectives.transferToInterpreterAndInvalidate(); - getDimNode = insert(GetDimAttributeNode.create()); - } - int[] dims = getDimNode.getDimensions(vector); - firstDim = dims[0]; - secondDim = dims[1]; - } else { - firstDim = length; - secondDim = 1; - } + assert vector.isMatrix(); + int[] dims = getDimensions(vector); + firstDim = dims[0]; + secondDim = dims[1]; RBaseNode.reportWork(this, length); A array = createArray.apply(length); @@ -110,36 +111,75 @@ public abstract class Transpose extends RBuiltinNode.Arg1 { copyRegAttributes.execute(vector, r); // set new dimensions int[] newDim = new int[]{secondDim, firstDim}; - putDimensions.execute(initAttributes.execute(r), RDataFactory.createIntVector(newDim, RDataFactory.COMPLETE_VECTOR)); - // set new dim names - RList dimNames = getDimNamesNode.getDimNames(vector); - if (dimNames != null) { - hasDimNamesProfile.enter(); - assert dimNames.getLength() == 2; - RStringVector axisNames = getAxisNamesNode.getNames(dimNames); - RStringVector transAxisNames = axisNames == null ? null : RDataFactory.createStringVector(new String[]{axisNames.getDataAt(1), axisNames.getDataAt(0)}, true); - RList newDimNames = RDataFactory.createList(new Object[]{dimNames.getDataAt(1), dimNames.getDataAt(0)}, transAxisNames); - putDimNames.execute(r.getAttributes(), newDimNames); - } + putNewDimensions(vector, r, newDim); return r; } - @Specialization + protected RVector<?> transposeSquareMatrixInPlace(RVector<?> vector, Swap swapper) { + int length = lengthProfile.profile(vector.getLength()); + assert vector.isMatrix(); + int[] dims = getDimensions(vector); + assert dims.length == 2; + assert dims[0] == dims[1]; + int dim = dims[0]; + RBaseNode.reportWork(this, length); + + loopProfile.profileCounted(length); + for (int i = 0; loopProfile.inject(i < dim); i++) { + for (int j = 0; j < i; j++) { + swapper.swap(i * dim + j, j * dim + i); + } + } + // don't need to set new dimensions; it is a square matrix + putNewDimNames(vector, vector); + return vector; + } + + private int[] getDimensions(RAbstractVector vector) { + assert vector.isMatrix(); + if (getDimNode == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + getDimNode = insert(GetDimAttributeNode.create()); + } + return getDimNode.getDimensions(vector); + } + + protected boolean isSquare(RAbstractVector vector) { + if (vector.isMatrix()) { + int[] dims = getDimensions(vector); + assert dims.length >= 2; + return dims[0] == dims[1]; + } + return false; + } + + @Specialization(guards = "isSquare(x)") + protected RVector<?> transposeSquare(RAbstractIntVector x) { + RVector<?> execute = reuseNonShared.execute(x); + int[] internalStore = (int[]) execute.getInternalStore(); + return transposeSquareMatrixInPlace(execute, (i, j) -> { + int tmp = internalStore[i]; + internalStore[i] = internalStore[j]; + internalStore[j] = tmp; + }); + } + + @Specialization(guards = "x.isMatrix()") protected RVector<?> transpose(RAbstractIntVector x) { return transposeInternal(x, l -> new int[l], (a, v, i, j) -> a[i] = v.getDataAt(j), RDataFactory::createIntVector); } - @Specialization + @Specialization(guards = "x.isMatrix()") protected RVector<?> transpose(RAbstractLogicalVector x) { return transposeInternal(x, l -> new byte[l], (a, v, i, j) -> a[i] = v.getDataAt(j), RDataFactory::createLogicalVector); } - @Specialization + @Specialization(guards = "x.isMatrix()") protected RVector<?> transpose(RAbstractDoubleVector x) { return transposeInternal(x, l -> new double[l], (a, v, i, j) -> a[i] = v.getDataAt(j), RDataFactory::createDoubleVector); } - @Specialization + @Specialization(guards = "x.isMatrix()") protected RVector<?> transpose(RAbstractComplexVector x) { return transposeInternal(x, l -> new double[l * 2], (a, v, i, j) -> { RComplex d = v.getDataAt(j); @@ -148,21 +188,48 @@ public abstract class Transpose extends RBuiltinNode.Arg1 { }, RDataFactory::createComplexVector); } - @Specialization + @Specialization(guards = "x.isMatrix()") protected RVector<?> transpose(RAbstractStringVector x) { return transposeInternal(x, l -> new String[l], (a, v, i, j) -> a[i] = v.getDataAt(j), RDataFactory::createStringVector); } - @Specialization + @Specialization(guards = "x.isMatrix()") protected RVector<?> transpose(RAbstractListVector x) { return transposeInternal(x, l -> new Object[l], (a, v, i, j) -> a[i] = v.getDataAt(j), (a, c) -> RDataFactory.createList(a)); } - @Specialization + @Specialization(guards = "x.isMatrix()") protected RVector<?> transpose(RAbstractRawVector x) { return transposeInternal(x, l -> new byte[l], (a, v, i, j) -> a[i] = v.getRawDataAt(j), (a, c) -> RDataFactory.createRawVector(a)); } + @Specialization(guards = "!x.isMatrix()") + protected RVector<?> transpose(RAbstractVector x) { + RVector<?> reused = reuseNonShared.execute(x); + putNewDimensions(reused, reused, new int[]{1, x.getLength()}); + return reused; + + } + + private void putNewDimensions(RAbstractVector source, RVector<?> dest, int[] newDim) { + putDimensions.execute(initAttributes.execute(dest), RDataFactory.createIntVector(newDim, RDataFactory.COMPLETE_VECTOR)); + putNewDimNames(source, dest); + } + + private void putNewDimNames(RAbstractVector source, RVector<?> dest) { + // set new dim names + RList dimNames = getDimNamesNode.getDimNames(source); + if (dimNames != null) { + hasDimNamesProfile.enter(); + assert dimNames.getLength() == 2; + RStringVector axisNames = getAxisNamesNode.getNames(dimNames); + RStringVector transAxisNames = axisNames == null ? null : RDataFactory.createStringVector(new String[]{axisNames.getDataAt(1), axisNames.getDataAt(0)}, true); + RList newDimNames = RDataFactory.createList(new Object[]{dimNames.getDataAt(1), + dimNames.getDataAt(0)}, transAxisNames); + putDimNames.execute(dest.getAttributes(), newDimNames); + } + } + @Fallback protected RVector<?> transpose(@SuppressWarnings("unused") Object x) { throw error(Message.ARGUMENT_NOT_MATRIX); diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index 085cc297b2..933cf6864b 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -69567,6 +69567,38 @@ Error in t.default(new.env()) : argument is not a matrix b 1 c 2 +##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTransposeSquare# +#{ m <- matrix(1:64, 8, 8) ; sum(m * t(m)) } +[1] 72976 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTransposeSquare# +#{ m <- matrix(as.raw(c(1,2,3,4)), 2, 2); t(m) } + [,1] [,2] +[1,] 01 02 +[2,] 03 04 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTransposeSquare# +#{ m <- matrix(c('1', '2', '3', '4'), 2, 2); t(m) } + [,1] [,2] +[1,] "1" "2" +[2,] "3" "4" + +##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTransposeSquare# +#{ m <- matrix(c(T, T, F, F), 2, 2); t(m) } + [,1] [,2] +[1,] TRUE TRUE +[2,] FALSE FALSE + +##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTransposeSquare# +#{ m <- matrix(list(a=1,b=2,c=3,d=4), 2, 2); t(m) } + [,1] [,2] +[1,] 1 2 +[2,] 3 4 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTransposeSquare# +#{ m <- matrix(seq(0.01,0.64,0.01), 8, 8) ; sum(m * t(m)) } +[1] 7.2976 + ##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testt1# #argv <- structure(list(x = c(-2.13777446721376, 1.17045456767922, 5.85180137819007)), .Names = 'x');do.call('t', argv) [,1] [,2] [,3] diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_t.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_t.java index b4640e1d05..d1708c7290 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_t.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_t.java @@ -48,4 +48,15 @@ public class TestBuiltin_t extends TestBase { assertEval("t(as.raw(c(1,2,3,4)))"); assertEval("t(matrix(1:6, 3, 2, dimnames=list(x=c(\"x1\",\"x2\",\"x3\"),y=c(\"y1\",\"y2\"))))"); } + + @Test + public void testTransposeSquare() { + // test square matrices + assertEval("{ m <- matrix(1:64, 8, 8) ; sum(m * t(m)) }"); + assertEval("{ m <- matrix(seq(0.01,0.64,0.01), 8, 8) ; sum(m * t(m)) }"); + assertEval("{ m <- matrix(c(T, T, F, F), 2, 2); t(m) }"); + assertEval("{ m <- matrix(c('1', '2', '3', '4'), 2, 2); t(m) }"); + assertEval("{ m <- matrix(as.raw(c(1,2,3,4)), 2, 2); t(m) }"); + assertEval("{ m <- matrix(list(a=1,b=2,c=3,d=4), 2, 2); t(m) }"); + } } -- GitLab