Skip to content
Snippets Groups Projects
Commit 5b2369cb authored by Lukas Stadler's avatar Lukas Stadler
Browse files

implement cumsum/cumprod using VectorAccess

parent 94460377
No related branches found
No related tags found
No related merge requests found
...@@ -19,6 +19,7 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; ...@@ -19,6 +19,7 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
import java.util.Arrays; import java.util.Arrays;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
...@@ -32,16 +33,14 @@ import com.oracle.truffle.r.runtime.data.RDoubleVector; ...@@ -32,16 +33,14 @@ import com.oracle.truffle.r.runtime.data.RDoubleVector;
import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; import com.oracle.truffle.r.runtime.ops.BinaryArithmetic;
import com.oracle.truffle.r.runtime.ops.na.NACheck;
@RBuiltin(name = "cumprod", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE) @RBuiltin(name = "cumprod", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE)
public abstract class CumProd extends RBuiltinNode.Arg1 { public abstract class CumProd extends RBuiltinNode.Arg1 {
private final NACheck na = NACheck.create();
@Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create(); @Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create();
@Child private BinaryArithmetic mul = BinaryArithmetic.MULTIPLY.createOperation(); @Child private BinaryArithmetic mul = BinaryArithmetic.MULTIPLY.createOperation();
static { static {
...@@ -55,63 +54,63 @@ public abstract class CumProd extends RBuiltinNode.Arg1 { ...@@ -55,63 +54,63 @@ public abstract class CumProd extends RBuiltinNode.Arg1 {
} }
@Specialization @Specialization
protected RDoubleVector cumNull(@SuppressWarnings("unused") RNull rnull) { protected RDoubleVector cumNull(@SuppressWarnings("unused") RNull x) {
return RDataFactory.createEmptyDoubleVector(); return RDataFactory.createEmptyDoubleVector();
} }
@Specialization(guards = "emptyVec.getLength()==0") @Specialization(guards = "xAccess.supports(x)")
protected RAbstractVector cumEmpty(RAbstractComplexVector emptyVec) { protected RDoubleVector cumprodDouble(RAbstractDoubleVector x,
return RDataFactory.createComplexVector(new double[0], true, emptyVec.getNames()); @Cached("x.access()") VectorAccess xAccess) {
try (SequentialIterator iter = xAccess.access(x)) {
double[] array = new double[xAccess.getLength(iter)];
double prev = 1;
while (xAccess.next(iter)) {
double value = xAccess.getDouble(iter);
if (xAccess.na.check(value)) {
Arrays.fill(array, iter.getIndex(), array.length, RRuntime.DOUBLE_NA);
break;
}
if (xAccess.na.checkNAorNaN(value)) {
Arrays.fill(array, iter.getIndex(), array.length, Double.NaN);
break;
}
prev = mul.op(prev, value);
assert !RRuntime.isNA(prev) : "double multiplication should not introduce NAs";
array[iter.getIndex()] = prev;
}
return RDataFactory.createDoubleVector(array, xAccess.na.neverSeenNA(), getNamesNode.getNames(x));
}
} }
@Specialization(guards = "emptyVec.getLength()==0") @Specialization(replaces = "cumprodDouble")
protected RAbstractVector cumEmpty(RAbstractDoubleVector emptyVec) { protected RDoubleVector cumprodDoubleGeneric(RAbstractDoubleVector x) {
return RDataFactory.createDoubleVector(new double[0], true, emptyVec.getNames()); return cumprodDouble(x, x.slowPathAccess());
} }
@Specialization @Specialization(guards = "xAccess.supports(x)")
protected RDoubleVector cumprod(RAbstractDoubleVector arg) { protected RComplexVector cumprodComplex(RAbstractComplexVector x,
double[] array = new double[arg.getLength()]; @Cached("x.access()") VectorAccess xAccess) {
na.enable(arg); try (SequentialIterator iter = xAccess.access(x)) {
double prev = 1; double[] array = new double[xAccess.getLength(iter) * 2];
int i; RComplex prev = RDataFactory.createComplex(1, 0);
for (i = 0; i < arg.getLength(); i++) { while (xAccess.next(iter)) {
double value = arg.getDataAt(i); double real = xAccess.getComplexR(iter);
if (na.check(value)) { double imag = xAccess.getComplexI(iter);
Arrays.fill(array, i, array.length, RRuntime.DOUBLE_NA); if (xAccess.na.check(real, imag)) {
break; Arrays.fill(array, 2 * iter.getIndex(), array.length, RRuntime.DOUBLE_NA);
} break;
if (na.checkNAorNaN(value)) { }
Arrays.fill(array, i, array.length, Double.NaN); prev = mul.op(prev.getRealPart(), prev.getImaginaryPart(), real, imag);
break; assert !RRuntime.isNA(prev) : "complex multiplication should not introduce NAs";
array[iter.getIndex() * 2] = prev.getRealPart();
array[iter.getIndex() * 2 + 1] = prev.getImaginaryPart();
} }
prev = mul.op(prev, value); return RDataFactory.createComplexVector(array, xAccess.na.neverSeenNA(), getNamesNode.getNames(x));
array[i] = prev;
} }
return RDataFactory.createDoubleVector(array, na.neverSeenNA(), getNamesNode.getNames(arg));
} }
@Specialization @Specialization(replaces = "cumprodComplex")
protected RComplexVector cumprod(RAbstractComplexVector arg) { protected RComplexVector cumprodComplexGeneric(RAbstractComplexVector x) {
double[] array = new double[arg.getLength() * 2]; return cumprodComplex(x, x.slowPathAccess());
na.enable(arg);
RComplex prev = RDataFactory.createComplex(1, 0);
int i;
for (i = 0; i < arg.getLength(); i++) {
RComplex value = arg.getDataAt(i);
if (na.check(value)) {
break;
}
prev = mul.op(prev.getRealPart(), prev.getImaginaryPart(), value.getRealPart(), value.getImaginaryPart());
if (na.check(prev)) {
break;
}
array[i * 2] = prev.getRealPart();
array[i * 2 + 1] = prev.getImaginaryPart();
}
if (!na.neverSeenNA()) {
Arrays.fill(array, 2 * i, array.length, RRuntime.DOUBLE_NA);
}
return RDataFactory.createComplexVector(array, na.neverSeenNA(), getNamesNode.getNames(arg));
} }
} }
...@@ -36,6 +36,7 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; ...@@ -36,6 +36,7 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
import java.util.Arrays; import java.util.Arrays;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
...@@ -46,22 +47,20 @@ import com.oracle.truffle.r.runtime.data.RComplex; ...@@ -46,22 +47,20 @@ import com.oracle.truffle.r.runtime.data.RComplex;
import com.oracle.truffle.r.runtime.data.RComplexVector; import com.oracle.truffle.r.runtime.data.RComplexVector;
import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.RDoubleVector;
import com.oracle.truffle.r.runtime.data.RIntSequence;
import com.oracle.truffle.r.runtime.data.RIntVector; import com.oracle.truffle.r.runtime.data.RIntVector;
import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; import com.oracle.truffle.r.runtime.ops.BinaryArithmetic;
import com.oracle.truffle.r.runtime.ops.na.NACheck;
@RBuiltin(name = "cumsum", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE) @RBuiltin(name = "cumsum", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE)
public abstract class CumSum extends RBuiltinNode.Arg1 { public abstract class CumSum extends RBuiltinNode.Arg1 {
private final NACheck na = NACheck.create();
@Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create(); @Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create();
@Child private BinaryArithmetic add = BinaryArithmetic.ADD.createOperation(); @Child private BinaryArithmetic add = BinaryArithmetic.ADD.createOperation();
static { static {
...@@ -81,100 +80,107 @@ public abstract class CumSum extends RBuiltinNode.Arg1 { ...@@ -81,100 +80,107 @@ public abstract class CumSum extends RBuiltinNode.Arg1 {
} }
@Specialization @Specialization
protected RDoubleVector cumNull(@SuppressWarnings("unused") RNull rnull) { protected RDoubleVector cumNull(@SuppressWarnings("unused") RNull x) {
return RDataFactory.createEmptyDoubleVector(); return RDataFactory.createEmptyDoubleVector();
} }
@Specialization(guards = "emptyVec.getLength()==0") @Specialization(guards = "x.getLength()==0")
protected RAbstractVector cumEmpty(RAbstractComplexVector emptyVec) { protected RAbstractVector cumEmpty(RAbstractComplexVector x) {
return RDataFactory.createComplexVector(new double[0], true, emptyVec.getNames()); return RDataFactory.createComplexVector(new double[0], true, getNamesNode.getNames(x));
} }
@Specialization(guards = "emptyVec.getLength()==0") @Specialization(guards = "x.getLength()==0")
protected RAbstractVector cumEmpty(RAbstractDoubleVector emptyVec) { protected RAbstractVector cumEmpty(RAbstractDoubleVector x) {
return RDataFactory.createDoubleVector(new double[0], true, emptyVec.getNames()); return RDataFactory.createDoubleVector(new double[0], true, getNamesNode.getNames(x));
} }
@Specialization(guards = "emptyVec.getLength()==0") @Specialization(guards = "x.getLength()==0")
protected RAbstractVector cumEmpty(RAbstractIntVector emptyVec) { protected RAbstractVector cumEmpty(RAbstractIntVector x) {
return RDataFactory.createIntVector(new int[0], true, emptyVec.getNames()); return RDataFactory.createIntVector(new int[0], true, getNamesNode.getNames(x));
} }
@Specialization @Specialization(guards = "xAccess.supports(x)")
protected RIntVector cumsum(RIntSequence arg) { protected RIntVector cumsumInt(RAbstractIntVector x,
int[] res = new int[arg.getLength()]; @Cached("x.access()") VectorAccess xAccess) {
int current = arg.getStart(); try (SequentialIterator iter = xAccess.access(x)) {
int prev = 0; int[] array = new int[xAccess.getLength(iter)];
na.enable(true); int prev = 0;
for (int i = 0; i < arg.getLength(); i++) { while (xAccess.next(iter)) {
prev = add.op(prev, current); int value = xAccess.getInt(iter);
if (na.check(prev)) { if (xAccess.na.check(value)) {
Arrays.fill(res, i, res.length, RRuntime.INT_NA); Arrays.fill(array, iter.getIndex(), array.length, RRuntime.INT_NA);
break; break;
}
prev = add.op(prev, value);
// integer addition can introduce NAs
if (add.introducesNA() && RRuntime.isNA(prev)) {
Arrays.fill(array, iter.getIndex(), array.length, RRuntime.INT_NA);
break;
}
array[iter.getIndex()] = prev;
} }
current += arg.getStride(); return RDataFactory.createIntVector(array, xAccess.na.neverSeenNA() && !add.introducesNA(), getNamesNode.getNames(x));
res[i] = prev;
} }
return RDataFactory.createIntVector(res, na.neverSeenNA(), getNamesNode.getNames(arg));
} }
@Specialization @Specialization(replaces = "cumsumInt")
protected RDoubleVector cumsum(RAbstractDoubleVector arg) { protected RIntVector cumsumIntGeneric(RAbstractIntVector x) {
double[] res = new double[arg.getLength()]; return cumsumInt(x, x.slowPathAccess());
double prev = 0.0;
na.enable(true);
for (int i = 0; i < arg.getLength(); i++) {
double value = arg.getDataAt(i);
// cumsum behaves different than cumprod for NaNs:
if (na.check(value)) {
Arrays.fill(res, i, res.length, RRuntime.DOUBLE_NA);
break;
} else if (na.checkNAorNaN(value)) {
Arrays.fill(res, i, res.length, Double.NaN);
break;
}
prev = add.op(prev, value);
res[i] = prev;
}
return RDataFactory.createDoubleVector(res, na.neverSeenNA(), getNamesNode.getNames(arg));
} }
@Specialization @Specialization(guards = "xAccess.supports(x)")
protected RIntVector cumsum(RAbstractIntVector arg) { protected RDoubleVector cumsumDouble(RAbstractDoubleVector x,
int[] res = new int[arg.getLength()]; @Cached("x.access()") VectorAccess xAccess) {
int prev = 0; try (SequentialIterator iter = xAccess.access(x)) {
int i; double[] array = new double[xAccess.getLength(iter)];
na.enable(true); double prev = 0;
for (i = 0; i < arg.getLength(); i++) { while (xAccess.next(iter)) {
if (na.check(arg.getDataAt(i))) { double value = xAccess.getDouble(iter);
break; if (xAccess.na.check(value)) {
} Arrays.fill(array, iter.getIndex(), array.length, RRuntime.DOUBLE_NA);
prev = add.op(prev, arg.getDataAt(i)); break;
if (na.check(prev)) { }
break; if (xAccess.na.checkNAorNaN(value)) {
Arrays.fill(array, iter.getIndex(), array.length, Double.NaN);
break;
}
prev = add.op(prev, value);
assert !RRuntime.isNA(prev) : "double addition should not introduce NAs";
array[iter.getIndex()] = prev;
} }
res[i] = prev; return RDataFactory.createDoubleVector(array, xAccess.na.neverSeenNA(), getNamesNode.getNames(x));
} }
if (!na.neverSeenNA()) {
Arrays.fill(res, i, res.length, RRuntime.INT_NA);
}
return RDataFactory.createIntVector(res, na.neverSeenNA(), getNamesNode.getNames(arg));
} }
@Specialization @Specialization(replaces = "cumsumDouble")
protected RComplexVector cumsum(RAbstractComplexVector arg) { protected RDoubleVector cumsumDoubleGeneric(RAbstractDoubleVector x) {
double[] res = new double[arg.getLength() * 2]; return cumsumDouble(x, x.slowPathAccess());
RComplex prev = RDataFactory.createComplex(0.0, 0.0); }
na.enable(true);
for (int i = 0; i < arg.getLength(); i++) { @Specialization(guards = "xAccess.supports(x)")
prev = add.op(prev.getRealPart(), prev.getImaginaryPart(), arg.getDataAt(i).getRealPart(), arg.getDataAt(i).getImaginaryPart()); protected RComplexVector cumsumComplex(RAbstractComplexVector x,
if (na.check(arg.getDataAt(i))) { @Cached("x.access()") VectorAccess xAccess) {
Arrays.fill(res, 2 * i, res.length, RRuntime.DOUBLE_NA); try (SequentialIterator iter = xAccess.access(x)) {
break; double[] array = new double[xAccess.getLength(iter) * 2];
RComplex prev = RDataFactory.createComplex(0, 0);
while (xAccess.next(iter)) {
double real = xAccess.getComplexR(iter);
double imag = xAccess.getComplexI(iter);
if (xAccess.na.check(real, imag)) {
Arrays.fill(array, 2 * iter.getIndex(), array.length, RRuntime.DOUBLE_NA);
break;
}
prev = add.op(prev.getRealPart(), prev.getImaginaryPart(), real, imag);
assert !RRuntime.isNA(prev) : "complex addition should not introduce NAs";
array[iter.getIndex() * 2] = prev.getRealPart();
array[iter.getIndex() * 2 + 1] = prev.getImaginaryPart();
} }
res[2 * i] = prev.getRealPart(); return RDataFactory.createComplexVector(array, xAccess.na.neverSeenNA(), getNamesNode.getNames(x));
res[2 * i + 1] = prev.getImaginaryPart();
} }
return RDataFactory.createComplexVector(res, na.neverSeenNA(), getNamesNode.getNames(arg)); }
@Specialization(replaces = "cumsumComplex")
protected RComplexVector cumsumComplexGeneric(RAbstractComplexVector x) {
return cumsumComplex(x, x.slowPathAccess());
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment