Skip to content
Snippets Groups Projects
Commit 584daaaf authored by Tomas Stupka's avatar Tomas Stupka
Browse files

t builtin should convert names attribute to dimnames if source is 1-dim vector

parent 0eb7f440
No related branches found
No related tags found
No related merge requests found
......@@ -12,21 +12,25 @@
*/
package com.oracle.truffle.r.nodes.builtin.base;
import com.oracle.truffle.api.CompilerDirectives;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.object.DynamicObject;
import com.oracle.truffle.api.profiles.BranchProfile;
import com.oracle.truffle.api.profiles.LoopConditionProfile;
import com.oracle.truffle.r.nodes.attributes.CopyOfRegAttributesNode;
import com.oracle.truffle.r.nodes.attributes.CopyOfRegAttributesNodeGen;
import com.oracle.truffle.r.nodes.attributes.InitAttributesNode;
import com.oracle.truffle.r.nodes.attributes.RemoveAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SetFixedAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimNamesAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.ExtractNamesAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.function.opt.ReuseNonSharedNode;
import com.oracle.truffle.r.nodes.profile.VectorLengthProfile;
......@@ -37,8 +41,10 @@ import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RDataFactory.VectorFactory;
import com.oracle.truffle.r.runtime.data.RList;
import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.RVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator;
......@@ -61,6 +67,8 @@ public abstract class Transpose extends RBuiltinNode.Arg1 {
@Child private ExtractNamesAttributeNode extractAxisNamesNode = ExtractNamesAttributeNode.create();
@Child private GetDimAttributeNode getDimNode = GetDimAttributeNode.create();
@Child private ReuseNonSharedNode reuseNonShared = ReuseNonSharedNode.create();
@Child private GetNamesAttributeNode getNamesNode;
@Child private RemoveAttributeNode removeAttributeNode;
static {
Casts.noCasts(Transpose.class);
......@@ -146,7 +154,7 @@ public abstract class Transpose extends RBuiltinNode.Arg1 {
}
}
// don't need to set new dimensions; it is a square matrix
putNewDimNames(vector, vector);
convertDimNames(vector, vector);
}
@Specialization(guards = {"isSquare(x)", "!isRExpression(x)", "xReuse.supports(x)"})
......@@ -192,7 +200,7 @@ public abstract class Transpose extends RBuiltinNode.Arg1 {
// copy attributes
copyRegAttributes.execute(x, result);
// set new dimensions
putNewDimensions(x, result, new int[]{secondDim, firstDim});
putNewDimsFromDimnames(x, result, new int[]{secondDim, firstDim});
}
result.setComplete(x.isComplete());
return result;
......@@ -208,17 +216,21 @@ public abstract class Transpose extends RBuiltinNode.Arg1 {
@Specialization(guards = {"!isMatrix(x)", "!isRExpression(x)"})
protected RVector<?> transposeNonMatrix(RAbstractVector x) {
RVector<?> reused = reuseNonShared.execute(x);
putNewDimensions(reused, reused, new int[]{1, x.getLength()});
putNewDimsFromNames(reused, reused, new int[]{1, x.getLength()});
return reused;
}
private void putNewDimsFromDimnames(RAbstractVector source, RAbstractVector dest, int[] newDim) {
putDimensions.execute(initAttributes.execute(dest), RDataFactory.createIntVector(newDim, RDataFactory.COMPLETE_VECTOR));
convertDimNames(source, dest);
}
private void putNewDimensions(RAbstractVector source, RAbstractVector dest, int[] newDim) {
private void putNewDimsFromNames(RAbstractVector source, RAbstractVector dest, int[] newDim) {
putDimensions.execute(initAttributes.execute(dest), RDataFactory.createIntVector(newDim, RDataFactory.COMPLETE_VECTOR));
putNewDimNames(source, dest);
convertNamesToDimnames(source, dest);
}
private void putNewDimNames(RAbstractVector source, RAbstractVector dest) {
private void convertDimNames(RAbstractVector source, RAbstractVector dest) {
// set new dim names
RList dimNames = getDimNamesNode.getDimNames(source);
if (dimNames != null) {
......@@ -231,6 +243,24 @@ public abstract class Transpose extends RBuiltinNode.Arg1 {
}
}
private void convertNamesToDimnames(RAbstractVector source, RAbstractVector dest) {
if (getNamesNode == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
getNamesNode = insert(GetNamesAttributeNode.create());
}
RAbstractStringVector names = (RAbstractStringVector) getNamesNode.execute(source);
if (names != null) {
RList newDimNames = RDataFactory.createList(new Object[]{RNull.instance, names});
DynamicObject attributes = dest.getAttributes();
putDimNames.execute(attributes, newDimNames);
if (removeAttributeNode == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
removeAttributeNode = insert(RemoveAttributeNode.create());
}
removeAttributeNode.execute(attributes, "names");
}
}
@Fallback
protected RVector<?> transpose(@SuppressWarnings("unused") Object x) {
throw error(Message.ARGUMENT_NOT_MATRIX);
......
......@@ -4,7 +4,7 @@
* http://www.gnu.org/licenses/gpl-2.0.html
*
* Copyright (c) 2012-2014, Purdue University
* Copyright (c) 2013, 2017, Oracle and/or its affiliates
* Copyright (c) 2013, 2018, Oracle and/or its affiliates
*
* All rights reserved.
*/
......@@ -34,6 +34,8 @@ public class TestBuiltin_t extends TestBase {
assertEval("{ t(1:3) }");
assertEval("{ t(t(t(1:3))) }");
assertEval("{ x <- 1:3; names(x) <- c('a', 'b'); t(x) }");
assertEval("{ t(matrix(1:6, nrow=2)) }");
assertEval("{ t(t(matrix(1:6, nrow=2))) }");
assertEval("{ t(matrix(1:4, nrow=2)) }");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment