From 4f5fa3a5fe0b7ea790d30d07ece2c8a228fd32f6 Mon Sep 17 00:00:00 2001
From: Florian Angerer <florian.angerer@oracle.com>
Date: Tue, 22 Aug 2017 16:23:09 +0200
Subject: [PATCH] Implemented in-place transpose for simple cases.

---
 .../r/nodes/builtin/base/Transpose.java       | 127 +++++++++++++-----
 .../truffle/r/test/ExpectedTestOutput.test    |  32 +++++
 .../r/test/builtins/TestBuiltin_t.java        |  11 ++
 3 files changed, 140 insertions(+), 30 deletions(-)

diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java
index abad492f9e..7c650783f2 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java
@@ -32,6 +32,7 @@ import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAt
 import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimNamesAttributeNode;
 import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
 import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
+import com.oracle.truffle.r.nodes.function.opt.ReuseNonSharedNode;
 import com.oracle.truffle.r.nodes.profile.VectorLengthProfile;
 import com.oracle.truffle.r.runtime.RError.Message;
 import com.oracle.truffle.r.runtime.builtins.RBuiltin;
@@ -55,6 +56,7 @@ public abstract class Transpose extends RBuiltinNode.Arg1 {
 
     private final BranchProfile hasDimNamesProfile = BranchProfile.create();
     private final ConditionProfile isMatrixProfile = ConditionProfile.createBinaryProfile();
+    private final BranchProfile isNonSharedProfile = BranchProfile.create();
 
     private final VectorLengthProfile lengthProfile = VectorLengthProfile.create();
     private final LoopConditionProfile loopProfile = LoopConditionProfile.createCountingProfile();
@@ -66,6 +68,7 @@ public abstract class Transpose extends RBuiltinNode.Arg1 {
     @Child private GetDimNamesAttributeNode getDimNamesNode = GetDimNamesAttributeNode.create();
     @Child private GetNamesAttributeNode getAxisNamesNode = GetNamesAttributeNode.create();
     @Child private GetDimAttributeNode getDimNode;
+    @Child private ReuseNonSharedNode reuseNonShared = ReuseNonSharedNode.create();
 
     static {
         Casts.noCasts(Transpose.class);
@@ -78,22 +81,20 @@ public abstract class Transpose extends RBuiltinNode.Arg1 {
         void apply(A array, T vector, int i, int j);
     }
 
+    @FunctionalInterface
+    private interface Swap {
+        /** Swap element at (i, j) with element at (j, i). */
+        void swap(int i, int j);
+    }
+
     protected <T extends RAbstractVector, A> RVector<?> transposeInternal(T vector, Function<Integer, A> createArray, WriteArray<T, A> writeArray, BiFunction<A, Boolean, RVector<?>> createResult) {
         int length = lengthProfile.profile(vector.getLength());
         int firstDim;
         int secondDim;
-        if (isMatrixProfile.profile(vector.isMatrix())) {
-            if (getDimNode == null) {
-                CompilerDirectives.transferToInterpreterAndInvalidate();
-                getDimNode = insert(GetDimAttributeNode.create());
-            }
-            int[] dims = getDimNode.getDimensions(vector);
-            firstDim = dims[0];
-            secondDim = dims[1];
-        } else {
-            firstDim = length;
-            secondDim = 1;
-        }
+        assert vector.isMatrix();
+        int[] dims = getDimensions(vector);
+        firstDim = dims[0];
+        secondDim = dims[1];
         RBaseNode.reportWork(this, length);
 
         A array = createArray.apply(length);
@@ -110,36 +111,75 @@ public abstract class Transpose extends RBuiltinNode.Arg1 {
         copyRegAttributes.execute(vector, r);
         // set new dimensions
         int[] newDim = new int[]{secondDim, firstDim};
-        putDimensions.execute(initAttributes.execute(r), RDataFactory.createIntVector(newDim, RDataFactory.COMPLETE_VECTOR));
-        // set new dim names
-        RList dimNames = getDimNamesNode.getDimNames(vector);
-        if (dimNames != null) {
-            hasDimNamesProfile.enter();
-            assert dimNames.getLength() == 2;
-            RStringVector axisNames = getAxisNamesNode.getNames(dimNames);
-            RStringVector transAxisNames = axisNames == null ? null : RDataFactory.createStringVector(new String[]{axisNames.getDataAt(1), axisNames.getDataAt(0)}, true);
-            RList newDimNames = RDataFactory.createList(new Object[]{dimNames.getDataAt(1), dimNames.getDataAt(0)}, transAxisNames);
-            putDimNames.execute(r.getAttributes(), newDimNames);
-        }
+        putNewDimensions(vector, r, newDim);
         return r;
     }
 
-    @Specialization
+    protected RVector<?> transposeSquareMatrixInPlace(RVector<?> vector, Swap swapper) {
+        int length = lengthProfile.profile(vector.getLength());
+        assert vector.isMatrix();
+        int[] dims = getDimensions(vector);
+        assert dims.length == 2;
+        assert dims[0] == dims[1];
+        int dim = dims[0];
+        RBaseNode.reportWork(this, length);
+
+        loopProfile.profileCounted(length);
+        for (int i = 0; loopProfile.inject(i < dim); i++) {
+            for (int j = 0; j < i; j++) {
+                swapper.swap(i * dim + j, j * dim + i);
+            }
+        }
+        // don't need to set new dimensions; it is a square matrix
+        putNewDimNames(vector, vector);
+        return vector;
+    }
+
+    private int[] getDimensions(RAbstractVector vector) {
+        assert vector.isMatrix();
+        if (getDimNode == null) {
+            CompilerDirectives.transferToInterpreterAndInvalidate();
+            getDimNode = insert(GetDimAttributeNode.create());
+        }
+        return getDimNode.getDimensions(vector);
+    }
+
+    protected boolean isSquare(RAbstractVector vector) {
+        if (vector.isMatrix()) {
+            int[] dims = getDimensions(vector);
+            assert dims.length >= 2;
+            return dims[0] == dims[1];
+        }
+        return false;
+    }
+
+    @Specialization(guards = "isSquare(x)")
+    protected RVector<?> transposeSquare(RAbstractIntVector x) {
+        RVector<?> execute = reuseNonShared.execute(x);
+        int[] internalStore = (int[]) execute.getInternalStore();
+        return transposeSquareMatrixInPlace(execute, (i, j) -> {
+            int tmp = internalStore[i];
+            internalStore[i] = internalStore[j];
+            internalStore[j] = tmp;
+        });
+    }
+
+    @Specialization(guards = "x.isMatrix()")
     protected RVector<?> transpose(RAbstractIntVector x) {
         return transposeInternal(x, l -> new int[l], (a, v, i, j) -> a[i] = v.getDataAt(j), RDataFactory::createIntVector);
     }
 
-    @Specialization
+    @Specialization(guards = "x.isMatrix()")
     protected RVector<?> transpose(RAbstractLogicalVector x) {
         return transposeInternal(x, l -> new byte[l], (a, v, i, j) -> a[i] = v.getDataAt(j), RDataFactory::createLogicalVector);
     }
 
-    @Specialization
+    @Specialization(guards = "x.isMatrix()")
     protected RVector<?> transpose(RAbstractDoubleVector x) {
         return transposeInternal(x, l -> new double[l], (a, v, i, j) -> a[i] = v.getDataAt(j), RDataFactory::createDoubleVector);
     }
 
-    @Specialization
+    @Specialization(guards = "x.isMatrix()")
     protected RVector<?> transpose(RAbstractComplexVector x) {
         return transposeInternal(x, l -> new double[l * 2], (a, v, i, j) -> {
             RComplex d = v.getDataAt(j);
@@ -148,21 +188,48 @@ public abstract class Transpose extends RBuiltinNode.Arg1 {
         }, RDataFactory::createComplexVector);
     }
 
-    @Specialization
+    @Specialization(guards = "x.isMatrix()")
     protected RVector<?> transpose(RAbstractStringVector x) {
         return transposeInternal(x, l -> new String[l], (a, v, i, j) -> a[i] = v.getDataAt(j), RDataFactory::createStringVector);
     }
 
-    @Specialization
+    @Specialization(guards = "x.isMatrix()")
     protected RVector<?> transpose(RAbstractListVector x) {
         return transposeInternal(x, l -> new Object[l], (a, v, i, j) -> a[i] = v.getDataAt(j), (a, c) -> RDataFactory.createList(a));
     }
 
-    @Specialization
+    @Specialization(guards = "x.isMatrix()")
     protected RVector<?> transpose(RAbstractRawVector x) {
         return transposeInternal(x, l -> new byte[l], (a, v, i, j) -> a[i] = v.getRawDataAt(j), (a, c) -> RDataFactory.createRawVector(a));
     }
 
+    @Specialization(guards = "!x.isMatrix()")
+    protected RVector<?> transpose(RAbstractVector x) {
+        RVector<?> reused = reuseNonShared.execute(x);
+        putNewDimensions(reused, reused, new int[]{1, x.getLength()});
+        return reused;
+
+    }
+
+    private void putNewDimensions(RAbstractVector source, RVector<?> dest, int[] newDim) {
+        putDimensions.execute(initAttributes.execute(dest), RDataFactory.createIntVector(newDim, RDataFactory.COMPLETE_VECTOR));
+        putNewDimNames(source, dest);
+    }
+
+    private void putNewDimNames(RAbstractVector source, RVector<?> dest) {
+        // set new dim names
+        RList dimNames = getDimNamesNode.getDimNames(source);
+        if (dimNames != null) {
+            hasDimNamesProfile.enter();
+            assert dimNames.getLength() == 2;
+            RStringVector axisNames = getAxisNamesNode.getNames(dimNames);
+            RStringVector transAxisNames = axisNames == null ? null : RDataFactory.createStringVector(new String[]{axisNames.getDataAt(1), axisNames.getDataAt(0)}, true);
+            RList newDimNames = RDataFactory.createList(new Object[]{dimNames.getDataAt(1),
+                            dimNames.getDataAt(0)}, transAxisNames);
+            putDimNames.execute(dest.getAttributes(), newDimNames);
+        }
+    }
+
     @Fallback
     protected RVector<?> transpose(@SuppressWarnings("unused") Object x) {
         throw error(Message.ARGUMENT_NOT_MATRIX);
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
index 085cc297b2..933cf6864b 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
@@ -69567,6 +69567,38 @@ Error in t.default(new.env()) : argument is not a matrix
 b 1
 c 2
 
+##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTransposeSquare#
+#{ m <- matrix(1:64, 8, 8) ; sum(m * t(m)) }
+[1] 72976
+
+##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTransposeSquare#
+#{ m <- matrix(as.raw(c(1,2,3,4)), 2, 2); t(m) }
+     [,1] [,2]
+[1,]   01   02
+[2,]   03   04
+
+##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTransposeSquare#
+#{ m <- matrix(c('1', '2', '3', '4'), 2, 2); t(m) }
+     [,1] [,2]
+[1,] "1"  "2"
+[2,] "3"  "4"
+
+##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTransposeSquare#
+#{ m <- matrix(c(T, T, F, F), 2, 2); t(m) }
+      [,1]  [,2]
+[1,]  TRUE  TRUE
+[2,] FALSE FALSE
+
+##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTransposeSquare#
+#{ m <- matrix(list(a=1,b=2,c=3,d=4), 2, 2); t(m) }
+     [,1] [,2]
+[1,] 1    2
+[2,] 3    4
+
+##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testTransposeSquare#
+#{ m <- matrix(seq(0.01,0.64,0.01), 8, 8) ; sum(m * t(m)) }
+[1] 7.2976
+
 ##com.oracle.truffle.r.test.builtins.TestBuiltin_t.testt1#
 #argv <- structure(list(x = c(-2.13777446721376, 1.17045456767922,     5.85180137819007)), .Names = 'x');do.call('t', argv)
           [,1]     [,2]     [,3]
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_t.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_t.java
index b4640e1d05..d1708c7290 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_t.java
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_t.java
@@ -48,4 +48,15 @@ public class TestBuiltin_t extends TestBase {
         assertEval("t(as.raw(c(1,2,3,4)))");
         assertEval("t(matrix(1:6, 3, 2, dimnames=list(x=c(\"x1\",\"x2\",\"x3\"),y=c(\"y1\",\"y2\"))))");
     }
+
+    @Test
+    public void testTransposeSquare() {
+        // test square matrices
+        assertEval("{ m <- matrix(1:64, 8, 8) ; sum(m * t(m)) }");
+        assertEval("{ m <- matrix(seq(0.01,0.64,0.01), 8, 8) ; sum(m * t(m)) }");
+        assertEval("{ m <- matrix(c(T, T, F, F), 2, 2); t(m) }");
+        assertEval("{ m <- matrix(c('1', '2', '3', '4'), 2, 2); t(m) }");
+        assertEval("{ m <- matrix(as.raw(c(1,2,3,4)), 2, 2); t(m) }");
+        assertEval("{ m <- matrix(list(a=1,b=2,c=3,d=4), 2, 2); t(m) }");
+    }
 }
-- 
GitLab