Skip to content
Snippets Groups Projects
Commit e480dc70 authored by stepan's avatar stepan Committed by Mick Jordan
Browse files

Fix: crossprod does not promote dimnames

parent 75bb45b3
Branches
No related tags found
No related merge requests found
......@@ -42,7 +42,7 @@ public abstract class Crossprod extends RBuiltinNode {
private void ensureMatMult() {
if (matMult == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
matMult = insert(MatMultNodeGen.create(null));
matMult = insert(MatMultNodeGen.create(/* promoteDimNames: */ false, null));
}
}
......
......@@ -50,6 +50,7 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.nodes.RNode;
import com.oracle.truffle.r.runtime.ops.BinaryArithmetic;
import com.oracle.truffle.r.runtime.ops.na.NACheck;
......@@ -60,6 +61,7 @@ public abstract class MatMult extends RBuiltinNode {
@Child private BinaryMapArithmeticFunctionNode mult = new BinaryMapArithmeticFunctionNode(BinaryArithmetic.MULTIPLY.create());
@Child private BinaryMapArithmeticFunctionNode add = new BinaryMapArithmeticFunctionNode(BinaryArithmetic.ADD.create());
private final boolean promoteDimNames;
private final BranchProfile errorProfile = BranchProfile.create();
private final LoopConditionProfile mainLoopProfile = LoopConditionProfile.createCountingProfile();
......@@ -68,18 +70,23 @@ public abstract class MatMult extends RBuiltinNode {
private final ConditionProfile notOneRow = ConditionProfile.createBinaryProfile();
private final ConditionProfile notOneColumn = ConditionProfile.createBinaryProfile();
@CompilationFinal private RAttributeProfiles aDimAttributeProfile;
@CompilationFinal private RAttributeProfiles bDimAttributeProfile;
@CompilationFinal private ConditionProfile noDimAttributes;
private final RAttributeProfiles aDimAttributeProfile = RAttributeProfiles.create();
private final RAttributeProfiles bDimAttributeProfile = RAttributeProfiles.create();
private final ConditionProfile noDimAttributes = ConditionProfile.createBinaryProfile();
protected abstract Object executeObject(Object a, Object b);
private final NACheck na;
public MatMult() {
public MatMult(boolean promoteDimNames) {
this.promoteDimNames = promoteDimNames;
this.na = NACheck.create();
}
public static MatMult create(RNode[] arguments) {
return MatMultNodeGen.create(true, arguments);
}
@Specialization(guards = "bothZeroDim(a, b)")
protected RDoubleVector both0Dim(RAbstractDoubleVector a, RAbstractDoubleVector b) {
int r = b.getDimensions()[1];
......@@ -206,16 +213,9 @@ public abstract class MatMult extends RBuiltinNode {
}
RDoubleVector resultVec = RDataFactory.createDoubleVector(result, complete, new int[]{aRows, bCols});
if (aDimAttributeProfile == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
aDimAttributeProfile = RAttributeProfiles.create();
bDimAttributeProfile = RAttributeProfiles.create();
noDimAttributes = ConditionProfile.createBinaryProfile();
}
RList aDimNames = a.getDimNames(aDimAttributeProfile);
RList bDimNames = b.getDimNames(bDimAttributeProfile);
if (noDimAttributes.profile(aDimNames == null && bDimNames == null)) {
if (!promoteDimNames || noDimAttributes.profile(aDimNames == null && bDimNames == null)) {
return resultVec;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment