diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java index 6ee0bd71c9059fa8f27e81e34a64ce30e5612cbb..041f3d91b568f1c83d4d41349949e3c5edb974cd 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java @@ -12,21 +12,25 @@ */ package com.oracle.truffle.r.nodes.builtin.base; +import com.oracle.truffle.api.CompilerDirectives; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.object.DynamicObject; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.LoopConditionProfile; import com.oracle.truffle.r.nodes.attributes.CopyOfRegAttributesNode; import com.oracle.truffle.r.nodes.attributes.CopyOfRegAttributesNodeGen; import com.oracle.truffle.r.nodes.attributes.InitAttributesNode; +import com.oracle.truffle.r.nodes.attributes.RemoveAttributeNode; import com.oracle.truffle.r.nodes.attributes.SetFixedAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimNamesAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.ExtractNamesAttributeNode; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.function.opt.ReuseNonSharedNode; import com.oracle.truffle.r.nodes.profile.VectorLengthProfile; @@ -37,8 +41,10 @@ import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDataFactory.VectorFactory; import com.oracle.truffle.r.runtime.data.RList; +import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RStringVector; import com.oracle.truffle.r.runtime.data.RVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator; @@ -61,6 +67,8 @@ public abstract class Transpose extends RBuiltinNode.Arg1 { @Child private ExtractNamesAttributeNode extractAxisNamesNode = ExtractNamesAttributeNode.create(); @Child private GetDimAttributeNode getDimNode = GetDimAttributeNode.create(); @Child private ReuseNonSharedNode reuseNonShared = ReuseNonSharedNode.create(); + @Child private GetNamesAttributeNode getNamesNode; + @Child private RemoveAttributeNode removeAttributeNode; static { Casts.noCasts(Transpose.class); @@ -146,7 +154,7 @@ public abstract class Transpose extends RBuiltinNode.Arg1 { } } // don't need to set new dimensions; it is a square matrix - putNewDimNames(vector, vector); + convertDimNames(vector, vector); } @Specialization(guards = {"isSquare(x)", "!isRExpression(x)", "xReuse.supports(x)"}) @@ -192,7 +200,7 @@ public abstract class Transpose extends RBuiltinNode.Arg1 { // copy attributes copyRegAttributes.execute(x, result); // set new dimensions - putNewDimensions(x, result, new int[]{secondDim, firstDim}); + putNewDimsFromDimnames(x, result, new int[]{secondDim, firstDim}); } result.setComplete(x.isComplete()); return result; @@ -208,17 +216,21 @@ public abstract class Transpose extends RBuiltinNode.Arg1 { @Specialization(guards = {"!isMatrix(x)", "!isRExpression(x)"}) protected RVector<?> transposeNonMatrix(RAbstractVector x) { RVector<?> reused = reuseNonShared.execute(x); - putNewDimensions(reused, reused, new int[]{1, x.getLength()}); + putNewDimsFromNames(reused, reused, new int[]{1, x.getLength()}); return reused; + } + private void putNewDimsFromDimnames(RAbstractVector source, RAbstractVector dest, int[] newDim) { + putDimensions.execute(initAttributes.execute(dest), RDataFactory.createIntVector(newDim, RDataFactory.COMPLETE_VECTOR)); + convertDimNames(source, dest); } - private void putNewDimensions(RAbstractVector source, RAbstractVector dest, int[] newDim) { + private void putNewDimsFromNames(RAbstractVector source, RAbstractVector dest, int[] newDim) { putDimensions.execute(initAttributes.execute(dest), RDataFactory.createIntVector(newDim, RDataFactory.COMPLETE_VECTOR)); - putNewDimNames(source, dest); + convertNamesToDimnames(source, dest); } - private void putNewDimNames(RAbstractVector source, RAbstractVector dest) { + private void convertDimNames(RAbstractVector source, RAbstractVector dest) { // set new dim names RList dimNames = getDimNamesNode.getDimNames(source); if (dimNames != null) { @@ -231,6 +243,24 @@ public abstract class Transpose extends RBuiltinNode.Arg1 { } } + private void convertNamesToDimnames(RAbstractVector source, RAbstractVector dest) { + if (getNamesNode == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + getNamesNode = insert(GetNamesAttributeNode.create()); + } + RAbstractStringVector names = (RAbstractStringVector) getNamesNode.execute(source); + if (names != null) { + RList newDimNames = RDataFactory.createList(new Object[]{RNull.instance, names}); + DynamicObject attributes = dest.getAttributes(); + putDimNames.execute(attributes, newDimNames); + if (removeAttributeNode == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + removeAttributeNode = insert(RemoveAttributeNode.create()); + } + removeAttributeNode.execute(attributes, "names"); + } + } + @Fallback protected RVector<?> transpose(@SuppressWarnings("unused") Object x) { throw error(Message.ARGUMENT_NOT_MATRIX); diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_t.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_t.java index ca5b37d80674154befc4b2ec214dc4a11909ea31..e85550eebe9790f47c6bfe8d339fa778967caa4f 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_t.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_t.java @@ -4,7 +4,7 @@ * http://www.gnu.org/licenses/gpl-2.0.html * * Copyright (c) 2012-2014, Purdue University - * Copyright (c) 2013, 2017, Oracle and/or its affiliates + * Copyright (c) 2013, 2018, Oracle and/or its affiliates * * All rights reserved. */ @@ -34,6 +34,8 @@ public class TestBuiltin_t extends TestBase { assertEval("{ t(1:3) }"); assertEval("{ t(t(t(1:3))) }"); + assertEval("{ x <- 1:3; names(x) <- c('a', 'b'); t(x) }"); + assertEval("{ t(matrix(1:6, nrow=2)) }"); assertEval("{ t(t(matrix(1:6, nrow=2))) }"); assertEval("{ t(matrix(1:4, nrow=2)) }");