Skip to content
Snippets Groups Projects
Commit 5b2369cb authored by Lukas Stadler's avatar Lukas Stadler
Browse files

implement cumsum/cumprod using VectorAccess

parent 94460377
No related branches found
No related tags found
No related merge requests found
......@@ -19,6 +19,7 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
import java.util.Arrays;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
......@@ -32,16 +33,14 @@ import com.oracle.truffle.r.runtime.data.RDoubleVector;
import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
import com.oracle.truffle.r.runtime.ops.BinaryArithmetic;
import com.oracle.truffle.r.runtime.ops.na.NACheck;
@RBuiltin(name = "cumprod", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE)
public abstract class CumProd extends RBuiltinNode.Arg1 {
private final NACheck na = NACheck.create();
@Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create();
@Child private BinaryArithmetic mul = BinaryArithmetic.MULTIPLY.createOperation();
static {
......@@ -55,63 +54,63 @@ public abstract class CumProd extends RBuiltinNode.Arg1 {
}
@Specialization
protected RDoubleVector cumNull(@SuppressWarnings("unused") RNull rnull) {
protected RDoubleVector cumNull(@SuppressWarnings("unused") RNull x) {
return RDataFactory.createEmptyDoubleVector();
}
@Specialization(guards = "emptyVec.getLength()==0")
protected RAbstractVector cumEmpty(RAbstractComplexVector emptyVec) {
return RDataFactory.createComplexVector(new double[0], true, emptyVec.getNames());
@Specialization(guards = "xAccess.supports(x)")
protected RDoubleVector cumprodDouble(RAbstractDoubleVector x,
@Cached("x.access()") VectorAccess xAccess) {
try (SequentialIterator iter = xAccess.access(x)) {
double[] array = new double[xAccess.getLength(iter)];
double prev = 1;
while (xAccess.next(iter)) {
double value = xAccess.getDouble(iter);
if (xAccess.na.check(value)) {
Arrays.fill(array, iter.getIndex(), array.length, RRuntime.DOUBLE_NA);
break;
}
if (xAccess.na.checkNAorNaN(value)) {
Arrays.fill(array, iter.getIndex(), array.length, Double.NaN);
break;
}
prev = mul.op(prev, value);
assert !RRuntime.isNA(prev) : "double multiplication should not introduce NAs";
array[iter.getIndex()] = prev;
}
return RDataFactory.createDoubleVector(array, xAccess.na.neverSeenNA(), getNamesNode.getNames(x));
}
}
@Specialization(guards = "emptyVec.getLength()==0")
protected RAbstractVector cumEmpty(RAbstractDoubleVector emptyVec) {
return RDataFactory.createDoubleVector(new double[0], true, emptyVec.getNames());
@Specialization(replaces = "cumprodDouble")
protected RDoubleVector cumprodDoubleGeneric(RAbstractDoubleVector x) {
return cumprodDouble(x, x.slowPathAccess());
}
@Specialization
protected RDoubleVector cumprod(RAbstractDoubleVector arg) {
double[] array = new double[arg.getLength()];
na.enable(arg);
double prev = 1;
int i;
for (i = 0; i < arg.getLength(); i++) {
double value = arg.getDataAt(i);
if (na.check(value)) {
Arrays.fill(array, i, array.length, RRuntime.DOUBLE_NA);
break;
}
if (na.checkNAorNaN(value)) {
Arrays.fill(array, i, array.length, Double.NaN);
break;
@Specialization(guards = "xAccess.supports(x)")
protected RComplexVector cumprodComplex(RAbstractComplexVector x,
@Cached("x.access()") VectorAccess xAccess) {
try (SequentialIterator iter = xAccess.access(x)) {
double[] array = new double[xAccess.getLength(iter) * 2];
RComplex prev = RDataFactory.createComplex(1, 0);
while (xAccess.next(iter)) {
double real = xAccess.getComplexR(iter);
double imag = xAccess.getComplexI(iter);
if (xAccess.na.check(real, imag)) {
Arrays.fill(array, 2 * iter.getIndex(), array.length, RRuntime.DOUBLE_NA);
break;
}
prev = mul.op(prev.getRealPart(), prev.getImaginaryPart(), real, imag);
assert !RRuntime.isNA(prev) : "complex multiplication should not introduce NAs";
array[iter.getIndex() * 2] = prev.getRealPart();
array[iter.getIndex() * 2 + 1] = prev.getImaginaryPart();
}
prev = mul.op(prev, value);
array[i] = prev;
return RDataFactory.createComplexVector(array, xAccess.na.neverSeenNA(), getNamesNode.getNames(x));
}
return RDataFactory.createDoubleVector(array, na.neverSeenNA(), getNamesNode.getNames(arg));
}
@Specialization
protected RComplexVector cumprod(RAbstractComplexVector arg) {
double[] array = new double[arg.getLength() * 2];
na.enable(arg);
RComplex prev = RDataFactory.createComplex(1, 0);
int i;
for (i = 0; i < arg.getLength(); i++) {
RComplex value = arg.getDataAt(i);
if (na.check(value)) {
break;
}
prev = mul.op(prev.getRealPart(), prev.getImaginaryPart(), value.getRealPart(), value.getImaginaryPart());
if (na.check(prev)) {
break;
}
array[i * 2] = prev.getRealPart();
array[i * 2 + 1] = prev.getImaginaryPart();
}
if (!na.neverSeenNA()) {
Arrays.fill(array, 2 * i, array.length, RRuntime.DOUBLE_NA);
}
return RDataFactory.createComplexVector(array, na.neverSeenNA(), getNamesNode.getNames(arg));
@Specialization(replaces = "cumprodComplex")
protected RComplexVector cumprodComplexGeneric(RAbstractComplexVector x) {
return cumprodComplex(x, x.slowPathAccess());
}
}
......@@ -36,6 +36,7 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
import java.util.Arrays;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
......@@ -46,22 +47,20 @@ import com.oracle.truffle.r.runtime.data.RComplex;
import com.oracle.truffle.r.runtime.data.RComplexVector;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RDoubleVector;
import com.oracle.truffle.r.runtime.data.RIntSequence;
import com.oracle.truffle.r.runtime.data.RIntVector;
import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
import com.oracle.truffle.r.runtime.ops.BinaryArithmetic;
import com.oracle.truffle.r.runtime.ops.na.NACheck;
@RBuiltin(name = "cumsum", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE)
public abstract class CumSum extends RBuiltinNode.Arg1 {
private final NACheck na = NACheck.create();
@Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create();
@Child private BinaryArithmetic add = BinaryArithmetic.ADD.createOperation();
static {
......@@ -81,100 +80,107 @@ public abstract class CumSum extends RBuiltinNode.Arg1 {
}
@Specialization
protected RDoubleVector cumNull(@SuppressWarnings("unused") RNull rnull) {
protected RDoubleVector cumNull(@SuppressWarnings("unused") RNull x) {
return RDataFactory.createEmptyDoubleVector();
}
@Specialization(guards = "emptyVec.getLength()==0")
protected RAbstractVector cumEmpty(RAbstractComplexVector emptyVec) {
return RDataFactory.createComplexVector(new double[0], true, emptyVec.getNames());
@Specialization(guards = "x.getLength()==0")
protected RAbstractVector cumEmpty(RAbstractComplexVector x) {
return RDataFactory.createComplexVector(new double[0], true, getNamesNode.getNames(x));
}
@Specialization(guards = "emptyVec.getLength()==0")
protected RAbstractVector cumEmpty(RAbstractDoubleVector emptyVec) {
return RDataFactory.createDoubleVector(new double[0], true, emptyVec.getNames());
@Specialization(guards = "x.getLength()==0")
protected RAbstractVector cumEmpty(RAbstractDoubleVector x) {
return RDataFactory.createDoubleVector(new double[0], true, getNamesNode.getNames(x));
}
@Specialization(guards = "emptyVec.getLength()==0")
protected RAbstractVector cumEmpty(RAbstractIntVector emptyVec) {
return RDataFactory.createIntVector(new int[0], true, emptyVec.getNames());
@Specialization(guards = "x.getLength()==0")
protected RAbstractVector cumEmpty(RAbstractIntVector x) {
return RDataFactory.createIntVector(new int[0], true, getNamesNode.getNames(x));
}
@Specialization
protected RIntVector cumsum(RIntSequence arg) {
int[] res = new int[arg.getLength()];
int current = arg.getStart();
int prev = 0;
na.enable(true);
for (int i = 0; i < arg.getLength(); i++) {
prev = add.op(prev, current);
if (na.check(prev)) {
Arrays.fill(res, i, res.length, RRuntime.INT_NA);
break;
@Specialization(guards = "xAccess.supports(x)")
protected RIntVector cumsumInt(RAbstractIntVector x,
@Cached("x.access()") VectorAccess xAccess) {
try (SequentialIterator iter = xAccess.access(x)) {
int[] array = new int[xAccess.getLength(iter)];
int prev = 0;
while (xAccess.next(iter)) {
int value = xAccess.getInt(iter);
if (xAccess.na.check(value)) {
Arrays.fill(array, iter.getIndex(), array.length, RRuntime.INT_NA);
break;
}
prev = add.op(prev, value);
// integer addition can introduce NAs
if (add.introducesNA() && RRuntime.isNA(prev)) {
Arrays.fill(array, iter.getIndex(), array.length, RRuntime.INT_NA);
break;
}
array[iter.getIndex()] = prev;
}
current += arg.getStride();
res[i] = prev;
return RDataFactory.createIntVector(array, xAccess.na.neverSeenNA() && !add.introducesNA(), getNamesNode.getNames(x));
}
return RDataFactory.createIntVector(res, na.neverSeenNA(), getNamesNode.getNames(arg));
}
@Specialization
protected RDoubleVector cumsum(RAbstractDoubleVector arg) {
double[] res = new double[arg.getLength()];
double prev = 0.0;
na.enable(true);
for (int i = 0; i < arg.getLength(); i++) {
double value = arg.getDataAt(i);
// cumsum behaves different than cumprod for NaNs:
if (na.check(value)) {
Arrays.fill(res, i, res.length, RRuntime.DOUBLE_NA);
break;
} else if (na.checkNAorNaN(value)) {
Arrays.fill(res, i, res.length, Double.NaN);
break;
}
prev = add.op(prev, value);
res[i] = prev;
}
return RDataFactory.createDoubleVector(res, na.neverSeenNA(), getNamesNode.getNames(arg));
@Specialization(replaces = "cumsumInt")
protected RIntVector cumsumIntGeneric(RAbstractIntVector x) {
return cumsumInt(x, x.slowPathAccess());
}
@Specialization
protected RIntVector cumsum(RAbstractIntVector arg) {
int[] res = new int[arg.getLength()];
int prev = 0;
int i;
na.enable(true);
for (i = 0; i < arg.getLength(); i++) {
if (na.check(arg.getDataAt(i))) {
break;
}
prev = add.op(prev, arg.getDataAt(i));
if (na.check(prev)) {
break;
@Specialization(guards = "xAccess.supports(x)")
protected RDoubleVector cumsumDouble(RAbstractDoubleVector x,
@Cached("x.access()") VectorAccess xAccess) {
try (SequentialIterator iter = xAccess.access(x)) {
double[] array = new double[xAccess.getLength(iter)];
double prev = 0;
while (xAccess.next(iter)) {
double value = xAccess.getDouble(iter);
if (xAccess.na.check(value)) {
Arrays.fill(array, iter.getIndex(), array.length, RRuntime.DOUBLE_NA);
break;
}
if (xAccess.na.checkNAorNaN(value)) {
Arrays.fill(array, iter.getIndex(), array.length, Double.NaN);
break;
}
prev = add.op(prev, value);
assert !RRuntime.isNA(prev) : "double addition should not introduce NAs";
array[iter.getIndex()] = prev;
}
res[i] = prev;
return RDataFactory.createDoubleVector(array, xAccess.na.neverSeenNA(), getNamesNode.getNames(x));
}
if (!na.neverSeenNA()) {
Arrays.fill(res, i, res.length, RRuntime.INT_NA);
}
return RDataFactory.createIntVector(res, na.neverSeenNA(), getNamesNode.getNames(arg));
}
@Specialization
protected RComplexVector cumsum(RAbstractComplexVector arg) {
double[] res = new double[arg.getLength() * 2];
RComplex prev = RDataFactory.createComplex(0.0, 0.0);
na.enable(true);
for (int i = 0; i < arg.getLength(); i++) {
prev = add.op(prev.getRealPart(), prev.getImaginaryPart(), arg.getDataAt(i).getRealPart(), arg.getDataAt(i).getImaginaryPart());
if (na.check(arg.getDataAt(i))) {
Arrays.fill(res, 2 * i, res.length, RRuntime.DOUBLE_NA);
break;
@Specialization(replaces = "cumsumDouble")
protected RDoubleVector cumsumDoubleGeneric(RAbstractDoubleVector x) {
return cumsumDouble(x, x.slowPathAccess());
}
@Specialization(guards = "xAccess.supports(x)")
protected RComplexVector cumsumComplex(RAbstractComplexVector x,
@Cached("x.access()") VectorAccess xAccess) {
try (SequentialIterator iter = xAccess.access(x)) {
double[] array = new double[xAccess.getLength(iter) * 2];
RComplex prev = RDataFactory.createComplex(0, 0);
while (xAccess.next(iter)) {
double real = xAccess.getComplexR(iter);
double imag = xAccess.getComplexI(iter);
if (xAccess.na.check(real, imag)) {
Arrays.fill(array, 2 * iter.getIndex(), array.length, RRuntime.DOUBLE_NA);
break;
}
prev = add.op(prev.getRealPart(), prev.getImaginaryPart(), real, imag);
assert !RRuntime.isNA(prev) : "complex addition should not introduce NAs";
array[iter.getIndex() * 2] = prev.getRealPart();
array[iter.getIndex() * 2 + 1] = prev.getImaginaryPart();
}
res[2 * i] = prev.getRealPart();
res[2 * i + 1] = prev.getImaginaryPart();
return RDataFactory.createComplexVector(array, xAccess.na.neverSeenNA(), getNamesNode.getNames(x));
}
return RDataFactory.createComplexVector(res, na.neverSeenNA(), getNamesNode.getNames(arg));
}
@Specialization(replaces = "cumsumComplex")
protected RComplexVector cumsumComplexGeneric(RAbstractComplexVector x) {
return cumsumComplex(x, x.slowPathAccess());
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment