diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandFunctionsNodes.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandFunctionsNodes.java index 93cd847a3d3b093e9e1a5afe3530ebe9a3290052..f8b8ad866ba0f07df2f543055be4ad89d74ca373 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandFunctionsNodes.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandFunctionsNodes.java @@ -13,11 +13,11 @@ package com.oracle.truffle.r.library.stats; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.missingValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.abstractVectorValue; -import static com.oracle.truffle.r.runtime.RError.SHOW_CALLER; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.missingValue; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue; import static com.oracle.truffle.r.runtime.RError.Message.INVALID_UNNAMED_ARGUMENTS; +import static com.oracle.truffle.r.runtime.RError.SHOW_CALLER; import java.util.Arrays; @@ -43,6 +43,7 @@ import com.oracle.truffle.r.runtime.data.RDouble; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.data.nodes.VectorIterator; import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction1_Double; import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction2_Double; import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction3_Double; @@ -161,6 +162,9 @@ public final class RandFunctionsNodes { protected abstract static class RandFunctionIntExecutorNode extends RandFunctionExecutorBase { @Child private RandFunction3_Double function; + @Child private VectorIterator.Double aIterator = VectorIterator.Double.createWrapAround(); + @Child private VectorIterator.Double bIterator = VectorIterator.Double.createWrapAround(); + @Child private VectorIterator.Double cIterator = VectorIterator.Double.createWrapAround(); protected RandFunctionIntExecutorNode(RandFunction3_Double function) { this.function = function; @@ -180,13 +184,16 @@ public final class RandFunctionsNodes { return RDataFactory.createIntVector(nansResult, false); } + Object aIt = aIterator.init(a); + Object bIt = bIterator.init(b); + Object cIt = cIterator.init(c); boolean nans = false; int[] result = new int[length]; nodeData.loopConditionProfile.profileCounted(length); for (int i = 0; nodeData.loopConditionProfile.inject(i < length); i++) { - double aValue = a.getDataAt(i % aLength); - double bValue = b.getDataAt(i % bLength); - double cValue = c.getDataAt(i % cLength); + double aValue = aIterator.next(a, aIt); + double bValue = bIterator.next(b, bIt); + double cValue = cIterator.next(c, cIt); double value = function.execute(aValue, bValue, cValue, randProvider); if (Double.isNaN(value) || value <= Integer.MIN_VALUE || value > Integer.MAX_VALUE) { nodeData.nan.enter(); @@ -206,6 +213,9 @@ public final class RandFunctionsNodes { protected abstract static class RandFunctionDoubleExecutorNode extends RandFunctionExecutorBase { @Child private RandFunction3_Double function; + @Child private VectorIterator.Double aIterator = VectorIterator.Double.createWrapAround(); + @Child private VectorIterator.Double bIterator = VectorIterator.Double.createWrapAround(); + @Child private VectorIterator.Double cIterator = VectorIterator.Double.createWrapAround(); protected RandFunctionDoubleExecutorNode(RandFunction3_Double function) { this.function = function; @@ -225,14 +235,17 @@ public final class RandFunctionsNodes { return RDataFactory.createDoubleVector(nansResult, false); } + Object aIt = aIterator.init(a); + Object bIt = bIterator.init(b); + Object cIt = cIterator.init(c); boolean nans = false; double[] result; result = new double[length]; nodeData.loopConditionProfile.profileCounted(length); for (int i = 0; nodeData.loopConditionProfile.inject(i < length); i++) { - double aValue = a.getDataAt(i % aLength); - double bValue = b.getDataAt(i % bLength); - double cValue = c.getDataAt(i % cLength); + double aValue = aIterator.next(a, aIt); + double bValue = bIterator.next(b, bIt); + double cValue = cIterator.next(c, cIt); double value = function.execute(aValue, bValue, cValue, randProvider); if (Double.isNaN(value) || RRuntime.isNA(value)) { nodeData.nan.enter(); diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/VectorIterator.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/VectorIterator.java index 350f80170e76964ef8be607705618235e9cc32d5..f2ffc6e80be0be8bb4e4c0e5e6056f15d3f91d51 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/VectorIterator.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/VectorIterator.java @@ -334,6 +334,10 @@ public abstract class VectorIterator<T> extends Node { public static Double create() { return new Double(false); } + + public static Double createWrapAround() { + return new Double(true); + } } public static final class Logical extends VectorIterator<Byte> {