Skip to content
Snippets Groups Projects
Commit 5def5142 authored by Lukas Stadler's avatar Lukas Stadler
Browse files

better specializations for vector/scalar operations in BinaryArithmeticNode

parent 369e8543
Branches
No related tags found
No related merge requests found
......@@ -52,9 +52,11 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode {
private final NACheck rightNACheck;
private final NACheck resultNACheck;
private final ConditionProfile noDimensionsProfile = ConditionProfile.createBinaryProfile();
private final ConditionProfile emptyVector = ConditionProfile.createBinaryProfile();
private final BranchProfile hasAttributesProfile = BranchProfile.create();
private final BranchProfile warningProfile = BranchProfile.create();
private final ConditionProfile leftLongerProfile = ConditionProfile.createBinaryProfile();
public BinaryArithmeticNode(BinaryArithmeticFactory factory, UnaryArithmeticFactory unaryFactory) {
this.arithmetic = factory.create();
......@@ -383,8 +385,22 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode {
return performArithmeticComplexEnableNACheck(left, right);
}
protected static boolean differentDimensions(RAbstractVector left, RAbstractVector right) {
return BinaryBooleanNode.differentDimensions(left, right);
protected boolean differentDimensions(RAbstractVector left, RAbstractVector right) {
if (noDimensionsProfile.profile(!left.hasDimensions() || !right.hasDimensions())) {
return false;
}
int[] leftDimensions = left.getDimensions();
int[] rightDimensions = right.getDimensions();
assert (leftDimensions != null && rightDimensions != null);
if (leftDimensions.length != rightDimensions.length) {
return true;
}
for (int i = 0; i < leftDimensions.length; i++) {
if (leftDimensions[i] != rightDimensions[i]) {
return true;
}
}
return false;
}
@SuppressWarnings("unused")
......@@ -395,6 +411,36 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode {
// int vector and vectors
public boolean isTemporary(RIntVector vector) {
return vector.isTemporary();
}
@Specialization(guards = {"supportsIntResult", "isTemporary(arguments[0])"})
protected RIntVector doIntVectorScalar(RIntVector left, int right) {
leftNACheck.enable(left);
rightNACheck.enable(right);
return performOpVectorScalar(left, RIntVector::getDataWithoutCopying, RDataFactory::createEmptyIntVector, (array, i) -> array[i] = performArithmetic(array[i], right));
}
@Specialization(guards = {"supportsIntResult", "isTemporary(arguments[1])"})
protected RIntVector doIntVectorScalar(int left, RIntVector right) {
leftNACheck.enable(left);
rightNACheck.enable(right);
return performOpVectorScalar(right, RIntVector::getDataWithoutCopying, RDataFactory::createEmptyIntVector, (array, i) -> array[i] = performArithmetic(left, array[i]));
}
@Specialization(guards = "!supportsIntResult")
protected RDoubleVector doIntVectorScalar(RIntVector left, double right) {
return performOpDifferentLength(left, double[]::new, RDataFactory::createEmptyDoubleVector, RDataFactory::createDoubleVector,
(array, i) -> array[i] = performArithmeticDouble(leftNACheck.convertIntToDouble(left.getDataAt(i)), right));
}
@Specialization(guards = "!supportsIntResult")
protected RDoubleVector doIntVectorScalar(double left, RIntVector right) {
return performOpDifferentLength(right, double[]::new, RDataFactory::createEmptyDoubleVector, RDataFactory::createDoubleVector,
(array, i) -> array[i] = performArithmeticDouble(left, rightNACheck.convertIntToDouble(right.getDataAt(i))));
}
@Specialization(guards = {"!areSameLength", "supportsIntResult", "!differentDimensions"})
protected RIntVector doIntVectorDifferentLength(RAbstractIntVector left, RAbstractIntVector right) {
return performIntVectorOpDifferentLength(left, right);
......@@ -502,37 +548,17 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode {
}
@Specialization(guards = "isTemporary(arguments[0])")
protected RDoubleVector doDoubleVector(RDoubleVector left, double right) {
int length = left.getLength();
if (emptyVector.profile(length == 0)) {
return left;
}
protected RDoubleVector doDoubleVectorScalar(RDoubleVector left, double right) {
leftNACheck.enable(left);
rightNACheck.enable(right);
resultNACheck.enable(arithmetic.introducesNA());
double[] result = left.getDataWithoutCopying();
for (int i = 0; i < length; ++i) {
result[i] = performArithmeticDouble(result[i], right);
}
left.setComplete(isComplete());
return left;
return performOpVectorScalar(left, RDoubleVector::getDataWithoutCopying, RDataFactory::createEmptyDoubleVector, (array, i) -> array[i] = performArithmeticDouble(array[i], right));
}
@Specialization(guards = "isTemporary(arguments[1])")
protected RDoubleVector doDoubleVector(double left, RDoubleVector right) {
int length = right.getLength();
if (emptyVector.profile(length == 0)) {
return right;
}
protected RDoubleVector doDoubleVectorScalar(double left, RDoubleVector right) {
leftNACheck.enable(left);
rightNACheck.enable(right);
resultNACheck.enable(arithmetic.introducesNA());
double[] result = right.getDataWithoutCopying();
for (int i = 0; i < length; ++i) {
result[i] = performArithmeticDouble(left, result[i]);
}
right.setComplete(isComplete());
return right;
return performOpVectorScalar(right, RDoubleVector::getDataWithoutCopying, RDataFactory::createEmptyDoubleVector, (array, i) -> array[i] = performArithmeticDouble(left, array[i]));
}
@Specialization
......@@ -762,24 +788,52 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode {
Supplier<ResultT> emptyConstructor, BiFunction<ArrayT, Boolean, ResultT> resultFunction, DifferentOpFunction<ArrayT> op) {
int leftLength = left.getLength();
int rightLength = right.getLength();
int length = Math.max(leftLength, rightLength);
if (emptyVector.profile(leftLength == 0 || rightLength == 0)) {
return emptyConstructor.get();
}
int length = Math.max(leftLength, rightLength);
leftNACheck.enable(left);
rightNACheck.enable(right);
resultNACheck.enable(arithmetic.introducesNA());
ArrayT result = arrayConstructor.apply(length);
int j = 0;
int k = 0;
for (int i = 0; i < length; ++i) {
op.apply(result, i, j, k);
j = Utils.incMod(j, leftLength);
k = Utils.incMod(k, rightLength);
boolean notMultiple = false;
if (leftLongerProfile.profile(leftLength > rightLength)) {
if (leftLength % rightLength != 0) {
warningProfile.enter();
notMultiple = true;
} else {
while (j < leftLength) {
k = 0;
while (k < rightLength) {
op.apply(result, j, j, k);
j++;
k++;
}
}
}
} else {
if (rightLength % leftLength != 0) {
warningProfile.enter();
notMultiple = true;
} else {
while (k < rightLength) {
j = 0;
while (j < leftLength) {
op.apply(result, k, j, k);
j++;
k++;
}
}
}
}
boolean notMultiple = j != 0 || k != 0;
if (notMultiple) {
warningProfile.enter();
for (int i = 0; i < length; ++i) {
op.apply(result, i, j, k);
j = Utils.incMod(j, leftLength);
k = Utils.incMod(k, rightLength);
}
RError.warning(RError.Message.LENGTH_NOT_MULTI);
}
ResultT ret = resultFunction.apply(result, isComplete());
......@@ -799,10 +853,26 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode {
op.apply(result, i);
}
ResultT ret = resultFunction.apply(result, isComplete());
copyAttributes(ret, source);
if (ret != source) {
copyAttributes(ret, source);
}
return ret;
}
private <ParamT extends RVector, ArrayT> ParamT performOpVectorScalar(ParamT source, Function<ParamT, ArrayT> arrayConstructor, Supplier<ParamT> emptyConstructor, SameOpFunction<ArrayT> op) {
int length = source.getLength();
if (emptyVector.profile(length == 0)) {
return emptyConstructor.get();
}
resultNACheck.enable(arithmetic.introducesNA());
ArrayT result = arrayConstructor.apply(source);
for (int i = 0; i < length; ++i) {
op.apply(result, i);
}
source.setComplete(isComplete());
return source;
}
private RComplexVector performComplexVectorOpDifferentLength(RAbstractComplexVector left, RAbstractComplexVector right) {
return performOpDifferentLength(left, right, len -> new double[len << 1], RDataFactory::createEmptyComplexVector, RDataFactory::createComplexVector, (array, i, j, k) -> {
RComplex result = performArithmeticComplex(left.getDataAt(j), right.getDataAt(k));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment