Skip to content
Snippets Groups Projects
Commit ef439641 authored by Lukas Stadler's avatar Lukas Stadler
Browse files

improve Transpose (better treatment of attributes)

parent e571fa8c
No related branches found
No related tags found
No related merge requests found
......@@ -14,18 +14,20 @@ package com.oracle.truffle.r.nodes.builtin.base;
import static com.oracle.truffle.r.runtime.RBuiltinKind.SUBSTITUTE;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.profiles.BranchProfile;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.r.nodes.attributes.CopyOfRegAttributesNode;
import com.oracle.truffle.r.nodes.attributes.CopyOfRegAttributesNodeGen;
import com.oracle.truffle.r.nodes.attributes.InitAttributesNode;
import com.oracle.truffle.r.nodes.attributes.PutAttributeNode;
import com.oracle.truffle.r.nodes.attributes.PutAttributeNodeGen;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.runtime.RBuiltin;
import com.oracle.truffle.r.runtime.data.RAttributeProfiles;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RDoubleVector;
import com.oracle.truffle.r.runtime.data.RIntVector;
import com.oracle.truffle.r.runtime.data.RList;
import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.RVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
......@@ -39,6 +41,12 @@ public abstract class Transpose extends RBuiltinNode {
private final RAttributeProfiles attrProfiles = RAttributeProfiles.create();
private final BranchProfile hasDimNamesProfile = BranchProfile.create();
private final ConditionProfile isMatrixProfile = ConditionProfile.createBinaryProfile();
@Child private CopyOfRegAttributesNode copyRegAttributes = CopyOfRegAttributesNodeGen.create();
@Child private InitAttributesNode initAttributes = InitAttributesNode.create();
@Child private PutAttributeNode putDimensions = PutAttributeNodeGen.createDim();
@Child private PutAttributeNode putDimNames = PutAttributeNodeGen.createDimNames();
public abstract Object execute(Object o);
......@@ -68,91 +76,91 @@ public abstract class Transpose extends RBuiltinNode {
return vector.copyWithNewDimensions(new int[]{dim[1], dim[0]});
}
@Specialization(guards = "!isEmpty2D(vector)")
protected RIntVector transpose(RAbstractIntVector vector) {
return performAbstractIntVector(vector, vector.isMatrix() ? vector.getDimensions() : new int[]{vector.getLength(), 1});
@FunctionalInterface
private interface InnerLoop<T extends RAbstractVector> {
RVector apply(T vector, int firstDim);
}
private RIntVector performAbstractIntVector(RAbstractIntVector vector, int[] dim) {
int firstDim = dim[0]; // rows
int secondDim = dim[1];
protected <T extends RAbstractVector> RVector transposeInternal(T vector, InnerLoop<T> innerLoop) {
int firstDim;
int secondDim;
if (isMatrixProfile.profile(vector.isMatrix())) {
firstDim = vector.getDimensions()[0];
secondDim = vector.getDimensions()[1];
} else {
firstDim = vector.getLength();
secondDim = 1;
}
RNode.reportWork(this, vector.getLength());
RVector r = innerLoop.apply(vector, firstDim);
// copy attributes
copyRegAttributes.execute(vector, r);
// set new dimensions
int[] newDim = new int[]{secondDim, firstDim};
r.setInternalDimensions(newDim);
putDimensions.execute(initAttributes.execute(r), RDataFactory.createIntVector(newDim, RDataFactory.COMPLETE_VECTOR));
// set new dim names
RList dimNames = vector.getDimNames(attrProfiles);
if (dimNames != null) {
hasDimNamesProfile.enter();
assert dimNames.getLength() == 2;
RList newDimNames = RDataFactory.createList(new Object[]{dimNames.getDataAt(1), dimNames.getDataAt(0)});
r.setInternalDimNames(newDimNames);
putDimNames.execute(r.getAttributes(), newDimNames);
}
return r;
}
private static RVector innerLoopInt(RAbstractIntVector vector, int firstDim) {
int[] result = new int[vector.getLength()];
int j = 0;
RNode.reportWork(this, vector.getLength());
for (int i = 0; i < result.length; i++, j += firstDim) {
if (j > (result.length - 1)) {
j -= (result.length - 1);
}
result[i] = vector.getDataAt(j);
}
int[] newDim = new int[]{secondDim, firstDim};
RIntVector r = RDataFactory.createIntVector(result, vector.isComplete());
r.copyAttributesFrom(attrProfiles, vector);
r.setDimensions(newDim);
setDimNames(r, vector);
return r;
return RDataFactory.createIntVector(result, vector.isComplete());
}
@Specialization(guards = "!isEmpty2D(vector)")
protected RDoubleVector transpose(RAbstractDoubleVector vector) {
return performAbstractDoubleVector(vector, vector.isMatrix() ? vector.getDimensions() : new int[]{vector.getLength(), 1});
}
private RDoubleVector performAbstractDoubleVector(RAbstractDoubleVector vector, int[] dim) {
int firstDim = dim[0];
int secondDim = dim[1];
private static RVector innerLoopDouble(RAbstractDoubleVector vector, int firstDim) {
double[] result = new double[vector.getLength()];
int j = 0;
RNode.reportWork(this, vector.getLength());
for (int i = 0; i < result.length; i++, j += firstDim) {
if (j > (result.length - 1)) {
j -= (result.length - 1);
}
result[i] = vector.getDataAt(j);
}
int[] newDim = new int[]{secondDim, firstDim};
RDoubleVector r = RDataFactory.createDoubleVector(result, vector.isComplete());
r.copyAttributesFrom(attrProfiles, vector);
r.setDimensions(newDim);
setDimNames(r, vector);
return r;
}
@Specialization(guards = "!isEmpty2D(vector)")
protected RStringVector transpose(RAbstractStringVector vector) {
return performAbstractStringVector(vector, vector.isMatrix() ? vector.getDimensions() : new int[]{vector.getLength(), 1});
return RDataFactory.createDoubleVector(result, vector.isComplete());
}
private RStringVector performAbstractStringVector(RAbstractStringVector vector, int[] dim) {
int firstDim = dim[0];
int secondDim = dim[1];
private static RVector innerLoopString(RAbstractStringVector vector, int firstDim) {
String[] result = new String[vector.getLength()];
int j = 0;
RNode.reportWork(this, vector.getLength());
for (int i = 0; i < result.length; i++, j += firstDim) {
if (j > (result.length - 1)) {
j -= (result.length - 1);
}
result[i] = vector.getDataAt(j);
}
int[] newDim = new int[]{secondDim, firstDim};
RStringVector r = RDataFactory.createStringVector(result, vector.isComplete());
r.copyAttributesFrom(attrProfiles, vector);
r.setDimensions(newDim);
setDimNames(r, vector);
return r;
return RDataFactory.createStringVector(result, vector.isComplete());
}
private void setDimNames(RVector newVector, RAbstractVector oldVector) {
RList dimNames = oldVector.getDimNames(attrProfiles);
if (dimNames != null) {
hasDimNamesProfile.enter();
assert dimNames.getLength() == 2;
newVector.setDimNames(RDataFactory.createList(new Object[]{dimNames.getDataAt(1), dimNames.getDataAt(0)}));
}
@Specialization(guards = "!isEmpty2D(vector)")
protected RVector transpose(RAbstractIntVector vector) {
return transposeInternal(vector, Transpose::innerLoopInt);
}
@Specialization(guards = "!isEmpty2D(vector)")
protected RVector transpose(RAbstractDoubleVector vector) {
return transposeInternal(vector, Transpose::innerLoopDouble);
}
@Specialization(guards = "!isEmpty2D(vector)")
protected RVector transpose(RAbstractStringVector vector) {
return transposeInternal(vector, Transpose::innerLoopString);
}
protected static boolean isEmpty2D(RAbstractVector vector) {
......
......@@ -52,6 +52,10 @@ public abstract class CopyAttributesNode extends RBaseNode {
this.copyAllAttributes = copyAllAttributes;
}
public static CopyAttributesNode createCopyAllAttributes() {
return CopyAttributesNodeGen.create(true);
}
public abstract RAbstractVector execute(RAbstractVector target, RAbstractVector left, int leftLength, RAbstractVector right, int rightLength);
protected boolean containsMetadata(RAbstractVector vector, RAttributeProfiles attrProfiles) {
......
......@@ -55,6 +55,10 @@ public abstract class PutAttributeNode extends RBaseNode {
return PutAttributeNodeGen.create(RRuntime.DIM_ATTR_KEY);
}
public static PutAttributeNode createDimNames() {
return PutAttributeNodeGen.create(RRuntime.DIMNAMES_ATTR_KEY);
}
public abstract void execute(RAttributes attr, Object value);
protected boolean nameMatches(RAttributes attr, int index) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment