From 5b2369cbb3e264eee61f49343ef311929593d413 Mon Sep 17 00:00:00 2001 From: Lukas Stadler <lukas.stadler@oracle.com> Date: Thu, 16 Nov 2017 12:02:37 +0100 Subject: [PATCH] implement cumsum/cumprod using VectorAccess --- .../truffle/r/nodes/builtin/base/CumProd.java | 101 ++++++----- .../truffle/r/nodes/builtin/base/CumSum.java | 162 +++++++++--------- 2 files changed, 134 insertions(+), 129 deletions(-) diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/CumProd.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/CumProd.java index 69767d4e3d..c89d7c817a 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/CumProd.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/CumProd.java @@ -19,6 +19,7 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; import java.util.Arrays; +import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; @@ -32,16 +33,14 @@ import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; -import com.oracle.truffle.r.runtime.ops.na.NACheck; @RBuiltin(name = "cumprod", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE) public abstract class CumProd extends RBuiltinNode.Arg1 { - private final NACheck na = NACheck.create(); @Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create(); - @Child private BinaryArithmetic mul = BinaryArithmetic.MULTIPLY.createOperation(); static { @@ -55,63 +54,63 @@ public abstract class CumProd extends RBuiltinNode.Arg1 { } @Specialization - protected RDoubleVector cumNull(@SuppressWarnings("unused") RNull rnull) { + protected RDoubleVector cumNull(@SuppressWarnings("unused") RNull x) { return RDataFactory.createEmptyDoubleVector(); } - @Specialization(guards = "emptyVec.getLength()==0") - protected RAbstractVector cumEmpty(RAbstractComplexVector emptyVec) { - return RDataFactory.createComplexVector(new double[0], true, emptyVec.getNames()); + @Specialization(guards = "xAccess.supports(x)") + protected RDoubleVector cumprodDouble(RAbstractDoubleVector x, + @Cached("x.access()") VectorAccess xAccess) { + try (SequentialIterator iter = xAccess.access(x)) { + double[] array = new double[xAccess.getLength(iter)]; + double prev = 1; + while (xAccess.next(iter)) { + double value = xAccess.getDouble(iter); + if (xAccess.na.check(value)) { + Arrays.fill(array, iter.getIndex(), array.length, RRuntime.DOUBLE_NA); + break; + } + if (xAccess.na.checkNAorNaN(value)) { + Arrays.fill(array, iter.getIndex(), array.length, Double.NaN); + break; + } + prev = mul.op(prev, value); + assert !RRuntime.isNA(prev) : "double multiplication should not introduce NAs"; + array[iter.getIndex()] = prev; + } + return RDataFactory.createDoubleVector(array, xAccess.na.neverSeenNA(), getNamesNode.getNames(x)); + } } - @Specialization(guards = "emptyVec.getLength()==0") - protected RAbstractVector cumEmpty(RAbstractDoubleVector emptyVec) { - return RDataFactory.createDoubleVector(new double[0], true, emptyVec.getNames()); + @Specialization(replaces = "cumprodDouble") + protected RDoubleVector cumprodDoubleGeneric(RAbstractDoubleVector x) { + return cumprodDouble(x, x.slowPathAccess()); } - @Specialization - protected RDoubleVector cumprod(RAbstractDoubleVector arg) { - double[] array = new double[arg.getLength()]; - na.enable(arg); - double prev = 1; - int i; - for (i = 0; i < arg.getLength(); i++) { - double value = arg.getDataAt(i); - if (na.check(value)) { - Arrays.fill(array, i, array.length, RRuntime.DOUBLE_NA); - break; - } - if (na.checkNAorNaN(value)) { - Arrays.fill(array, i, array.length, Double.NaN); - break; + @Specialization(guards = "xAccess.supports(x)") + protected RComplexVector cumprodComplex(RAbstractComplexVector x, + @Cached("x.access()") VectorAccess xAccess) { + try (SequentialIterator iter = xAccess.access(x)) { + double[] array = new double[xAccess.getLength(iter) * 2]; + RComplex prev = RDataFactory.createComplex(1, 0); + while (xAccess.next(iter)) { + double real = xAccess.getComplexR(iter); + double imag = xAccess.getComplexI(iter); + if (xAccess.na.check(real, imag)) { + Arrays.fill(array, 2 * iter.getIndex(), array.length, RRuntime.DOUBLE_NA); + break; + } + prev = mul.op(prev.getRealPart(), prev.getImaginaryPart(), real, imag); + assert !RRuntime.isNA(prev) : "complex multiplication should not introduce NAs"; + array[iter.getIndex() * 2] = prev.getRealPart(); + array[iter.getIndex() * 2 + 1] = prev.getImaginaryPart(); } - prev = mul.op(prev, value); - array[i] = prev; + return RDataFactory.createComplexVector(array, xAccess.na.neverSeenNA(), getNamesNode.getNames(x)); } - return RDataFactory.createDoubleVector(array, na.neverSeenNA(), getNamesNode.getNames(arg)); } - @Specialization - protected RComplexVector cumprod(RAbstractComplexVector arg) { - double[] array = new double[arg.getLength() * 2]; - na.enable(arg); - RComplex prev = RDataFactory.createComplex(1, 0); - int i; - for (i = 0; i < arg.getLength(); i++) { - RComplex value = arg.getDataAt(i); - if (na.check(value)) { - break; - } - prev = mul.op(prev.getRealPart(), prev.getImaginaryPart(), value.getRealPart(), value.getImaginaryPart()); - if (na.check(prev)) { - break; - } - array[i * 2] = prev.getRealPart(); - array[i * 2 + 1] = prev.getImaginaryPart(); - } - if (!na.neverSeenNA()) { - Arrays.fill(array, 2 * i, array.length, RRuntime.DOUBLE_NA); - } - return RDataFactory.createComplexVector(array, na.neverSeenNA(), getNamesNode.getNames(arg)); + @Specialization(replaces = "cumprodComplex") + protected RComplexVector cumprodComplexGeneric(RAbstractComplexVector x) { + return cumprodComplex(x, x.slowPathAccess()); } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/CumSum.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/CumSum.java index f0b47f2e76..a84d226b6a 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/CumSum.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/CumSum.java @@ -36,6 +36,7 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; import java.util.Arrays; +import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; @@ -46,22 +47,20 @@ import com.oracle.truffle.r.runtime.data.RComplex; import com.oracle.truffle.r.runtime.data.RComplexVector; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDoubleVector; -import com.oracle.truffle.r.runtime.data.RIntSequence; import com.oracle.truffle.r.runtime.data.RIntVector; import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; -import com.oracle.truffle.r.runtime.ops.na.NACheck; @RBuiltin(name = "cumsum", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = MATH_GROUP_GENERIC, behavior = PURE) public abstract class CumSum extends RBuiltinNode.Arg1 { - private final NACheck na = NACheck.create(); @Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create(); - @Child private BinaryArithmetic add = BinaryArithmetic.ADD.createOperation(); static { @@ -81,100 +80,107 @@ public abstract class CumSum extends RBuiltinNode.Arg1 { } @Specialization - protected RDoubleVector cumNull(@SuppressWarnings("unused") RNull rnull) { + protected RDoubleVector cumNull(@SuppressWarnings("unused") RNull x) { return RDataFactory.createEmptyDoubleVector(); } - @Specialization(guards = "emptyVec.getLength()==0") - protected RAbstractVector cumEmpty(RAbstractComplexVector emptyVec) { - return RDataFactory.createComplexVector(new double[0], true, emptyVec.getNames()); + @Specialization(guards = "x.getLength()==0") + protected RAbstractVector cumEmpty(RAbstractComplexVector x) { + return RDataFactory.createComplexVector(new double[0], true, getNamesNode.getNames(x)); } - @Specialization(guards = "emptyVec.getLength()==0") - protected RAbstractVector cumEmpty(RAbstractDoubleVector emptyVec) { - return RDataFactory.createDoubleVector(new double[0], true, emptyVec.getNames()); + @Specialization(guards = "x.getLength()==0") + protected RAbstractVector cumEmpty(RAbstractDoubleVector x) { + return RDataFactory.createDoubleVector(new double[0], true, getNamesNode.getNames(x)); } - @Specialization(guards = "emptyVec.getLength()==0") - protected RAbstractVector cumEmpty(RAbstractIntVector emptyVec) { - return RDataFactory.createIntVector(new int[0], true, emptyVec.getNames()); + @Specialization(guards = "x.getLength()==0") + protected RAbstractVector cumEmpty(RAbstractIntVector x) { + return RDataFactory.createIntVector(new int[0], true, getNamesNode.getNames(x)); } - @Specialization - protected RIntVector cumsum(RIntSequence arg) { - int[] res = new int[arg.getLength()]; - int current = arg.getStart(); - int prev = 0; - na.enable(true); - for (int i = 0; i < arg.getLength(); i++) { - prev = add.op(prev, current); - if (na.check(prev)) { - Arrays.fill(res, i, res.length, RRuntime.INT_NA); - break; + @Specialization(guards = "xAccess.supports(x)") + protected RIntVector cumsumInt(RAbstractIntVector x, + @Cached("x.access()") VectorAccess xAccess) { + try (SequentialIterator iter = xAccess.access(x)) { + int[] array = new int[xAccess.getLength(iter)]; + int prev = 0; + while (xAccess.next(iter)) { + int value = xAccess.getInt(iter); + if (xAccess.na.check(value)) { + Arrays.fill(array, iter.getIndex(), array.length, RRuntime.INT_NA); + break; + } + prev = add.op(prev, value); + // integer addition can introduce NAs + if (add.introducesNA() && RRuntime.isNA(prev)) { + Arrays.fill(array, iter.getIndex(), array.length, RRuntime.INT_NA); + break; + } + array[iter.getIndex()] = prev; } - current += arg.getStride(); - res[i] = prev; + return RDataFactory.createIntVector(array, xAccess.na.neverSeenNA() && !add.introducesNA(), getNamesNode.getNames(x)); } - return RDataFactory.createIntVector(res, na.neverSeenNA(), getNamesNode.getNames(arg)); } - @Specialization - protected RDoubleVector cumsum(RAbstractDoubleVector arg) { - double[] res = new double[arg.getLength()]; - double prev = 0.0; - na.enable(true); - for (int i = 0; i < arg.getLength(); i++) { - double value = arg.getDataAt(i); - // cumsum behaves different than cumprod for NaNs: - if (na.check(value)) { - Arrays.fill(res, i, res.length, RRuntime.DOUBLE_NA); - break; - } else if (na.checkNAorNaN(value)) { - Arrays.fill(res, i, res.length, Double.NaN); - break; - } - prev = add.op(prev, value); - res[i] = prev; - } - return RDataFactory.createDoubleVector(res, na.neverSeenNA(), getNamesNode.getNames(arg)); + @Specialization(replaces = "cumsumInt") + protected RIntVector cumsumIntGeneric(RAbstractIntVector x) { + return cumsumInt(x, x.slowPathAccess()); } - @Specialization - protected RIntVector cumsum(RAbstractIntVector arg) { - int[] res = new int[arg.getLength()]; - int prev = 0; - int i; - na.enable(true); - for (i = 0; i < arg.getLength(); i++) { - if (na.check(arg.getDataAt(i))) { - break; - } - prev = add.op(prev, arg.getDataAt(i)); - if (na.check(prev)) { - break; + @Specialization(guards = "xAccess.supports(x)") + protected RDoubleVector cumsumDouble(RAbstractDoubleVector x, + @Cached("x.access()") VectorAccess xAccess) { + try (SequentialIterator iter = xAccess.access(x)) { + double[] array = new double[xAccess.getLength(iter)]; + double prev = 0; + while (xAccess.next(iter)) { + double value = xAccess.getDouble(iter); + if (xAccess.na.check(value)) { + Arrays.fill(array, iter.getIndex(), array.length, RRuntime.DOUBLE_NA); + break; + } + if (xAccess.na.checkNAorNaN(value)) { + Arrays.fill(array, iter.getIndex(), array.length, Double.NaN); + break; + } + prev = add.op(prev, value); + assert !RRuntime.isNA(prev) : "double addition should not introduce NAs"; + array[iter.getIndex()] = prev; } - res[i] = prev; + return RDataFactory.createDoubleVector(array, xAccess.na.neverSeenNA(), getNamesNode.getNames(x)); } - if (!na.neverSeenNA()) { - Arrays.fill(res, i, res.length, RRuntime.INT_NA); - } - return RDataFactory.createIntVector(res, na.neverSeenNA(), getNamesNode.getNames(arg)); } - @Specialization - protected RComplexVector cumsum(RAbstractComplexVector arg) { - double[] res = new double[arg.getLength() * 2]; - RComplex prev = RDataFactory.createComplex(0.0, 0.0); - na.enable(true); - for (int i = 0; i < arg.getLength(); i++) { - prev = add.op(prev.getRealPart(), prev.getImaginaryPart(), arg.getDataAt(i).getRealPart(), arg.getDataAt(i).getImaginaryPart()); - if (na.check(arg.getDataAt(i))) { - Arrays.fill(res, 2 * i, res.length, RRuntime.DOUBLE_NA); - break; + @Specialization(replaces = "cumsumDouble") + protected RDoubleVector cumsumDoubleGeneric(RAbstractDoubleVector x) { + return cumsumDouble(x, x.slowPathAccess()); + } + + @Specialization(guards = "xAccess.supports(x)") + protected RComplexVector cumsumComplex(RAbstractComplexVector x, + @Cached("x.access()") VectorAccess xAccess) { + try (SequentialIterator iter = xAccess.access(x)) { + double[] array = new double[xAccess.getLength(iter) * 2]; + RComplex prev = RDataFactory.createComplex(0, 0); + while (xAccess.next(iter)) { + double real = xAccess.getComplexR(iter); + double imag = xAccess.getComplexI(iter); + if (xAccess.na.check(real, imag)) { + Arrays.fill(array, 2 * iter.getIndex(), array.length, RRuntime.DOUBLE_NA); + break; + } + prev = add.op(prev.getRealPart(), prev.getImaginaryPart(), real, imag); + assert !RRuntime.isNA(prev) : "complex addition should not introduce NAs"; + array[iter.getIndex() * 2] = prev.getRealPart(); + array[iter.getIndex() * 2 + 1] = prev.getImaginaryPart(); } - res[2 * i] = prev.getRealPart(); - res[2 * i + 1] = prev.getImaginaryPart(); + return RDataFactory.createComplexVector(array, xAccess.na.neverSeenNA(), getNamesNode.getNames(x)); } - return RDataFactory.createComplexVector(res, na.neverSeenNA(), getNamesNode.getNames(arg)); + } + + @Specialization(replaces = "cumsumComplex") + protected RComplexVector cumsumComplexGeneric(RAbstractComplexVector x) { + return cumsumComplex(x, x.slowPathAccess()); } } -- GitLab