Skip to content
Snippets Groups Projects
Commit 65c3be4e authored by stepan's avatar stepan
Browse files

User VectorIterator in RandFunctionsNodes

parent 8ad0de48
No related branches found
No related tags found
No related merge requests found
......@@ -13,11 +13,11 @@
package com.oracle.truffle.r.library.stats;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.missingValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.abstractVectorValue;
import static com.oracle.truffle.r.runtime.RError.SHOW_CALLER;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.missingValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
import static com.oracle.truffle.r.runtime.RError.Message.INVALID_UNNAMED_ARGUMENTS;
import static com.oracle.truffle.r.runtime.RError.SHOW_CALLER;
import java.util.Arrays;
......@@ -43,6 +43,7 @@ import com.oracle.truffle.r.runtime.data.RDouble;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorIterator;
import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction1_Double;
import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction2_Double;
import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction3_Double;
......@@ -161,6 +162,9 @@ public final class RandFunctionsNodes {
protected abstract static class RandFunctionIntExecutorNode extends RandFunctionExecutorBase {
@Child private RandFunction3_Double function;
@Child private VectorIterator.Double aIterator = VectorIterator.Double.createWrapAround();
@Child private VectorIterator.Double bIterator = VectorIterator.Double.createWrapAround();
@Child private VectorIterator.Double cIterator = VectorIterator.Double.createWrapAround();
protected RandFunctionIntExecutorNode(RandFunction3_Double function) {
this.function = function;
......@@ -180,13 +184,16 @@ public final class RandFunctionsNodes {
return RDataFactory.createIntVector(nansResult, false);
}
Object aIt = aIterator.init(a);
Object bIt = bIterator.init(b);
Object cIt = cIterator.init(c);
boolean nans = false;
int[] result = new int[length];
nodeData.loopConditionProfile.profileCounted(length);
for (int i = 0; nodeData.loopConditionProfile.inject(i < length); i++) {
double aValue = a.getDataAt(i % aLength);
double bValue = b.getDataAt(i % bLength);
double cValue = c.getDataAt(i % cLength);
double aValue = aIterator.next(a, aIt);
double bValue = bIterator.next(b, bIt);
double cValue = cIterator.next(c, cIt);
double value = function.execute(aValue, bValue, cValue, randProvider);
if (Double.isNaN(value) || value <= Integer.MIN_VALUE || value > Integer.MAX_VALUE) {
nodeData.nan.enter();
......@@ -206,6 +213,9 @@ public final class RandFunctionsNodes {
protected abstract static class RandFunctionDoubleExecutorNode extends RandFunctionExecutorBase {
@Child private RandFunction3_Double function;
@Child private VectorIterator.Double aIterator = VectorIterator.Double.createWrapAround();
@Child private VectorIterator.Double bIterator = VectorIterator.Double.createWrapAround();
@Child private VectorIterator.Double cIterator = VectorIterator.Double.createWrapAround();
protected RandFunctionDoubleExecutorNode(RandFunction3_Double function) {
this.function = function;
......@@ -225,14 +235,17 @@ public final class RandFunctionsNodes {
return RDataFactory.createDoubleVector(nansResult, false);
}
Object aIt = aIterator.init(a);
Object bIt = bIterator.init(b);
Object cIt = cIterator.init(c);
boolean nans = false;
double[] result;
result = new double[length];
nodeData.loopConditionProfile.profileCounted(length);
for (int i = 0; nodeData.loopConditionProfile.inject(i < length); i++) {
double aValue = a.getDataAt(i % aLength);
double bValue = b.getDataAt(i % bLength);
double cValue = c.getDataAt(i % cLength);
double aValue = aIterator.next(a, aIt);
double bValue = bIterator.next(b, bIt);
double cValue = cIterator.next(c, cIt);
double value = function.execute(aValue, bValue, cValue, randProvider);
if (Double.isNaN(value) || RRuntime.isNA(value)) {
nodeData.nan.enter();
......
......@@ -334,6 +334,10 @@ public abstract class VectorIterator<T> extends Node {
public static Double create() {
return new Double(false);
}
public static Double createWrapAround() {
return new Double(true);
}
}
public static final class Logical extends VectorIterator<Byte> {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment