Skip to content
Snippets Groups Projects
Commit 65c3be4e authored by stepan's avatar stepan
Browse files

User VectorIterator in RandFunctionsNodes

parent 8ad0de48
No related branches found
No related tags found
No related merge requests found
...@@ -13,11 +13,11 @@ ...@@ -13,11 +13,11 @@
package com.oracle.truffle.r.library.stats; package com.oracle.truffle.r.library.stats;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.missingValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.abstractVectorValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.abstractVectorValue;
import static com.oracle.truffle.r.runtime.RError.SHOW_CALLER; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.missingValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
import static com.oracle.truffle.r.runtime.RError.Message.INVALID_UNNAMED_ARGUMENTS; import static com.oracle.truffle.r.runtime.RError.Message.INVALID_UNNAMED_ARGUMENTS;
import static com.oracle.truffle.r.runtime.RError.SHOW_CALLER;
import java.util.Arrays; import java.util.Arrays;
...@@ -43,6 +43,7 @@ import com.oracle.truffle.r.runtime.data.RDouble; ...@@ -43,6 +43,7 @@ import com.oracle.truffle.r.runtime.data.RDouble;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorIterator;
import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction1_Double; import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction1_Double;
import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction2_Double; import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction2_Double;
import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction3_Double; import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction3_Double;
...@@ -161,6 +162,9 @@ public final class RandFunctionsNodes { ...@@ -161,6 +162,9 @@ public final class RandFunctionsNodes {
protected abstract static class RandFunctionIntExecutorNode extends RandFunctionExecutorBase { protected abstract static class RandFunctionIntExecutorNode extends RandFunctionExecutorBase {
@Child private RandFunction3_Double function; @Child private RandFunction3_Double function;
@Child private VectorIterator.Double aIterator = VectorIterator.Double.createWrapAround();
@Child private VectorIterator.Double bIterator = VectorIterator.Double.createWrapAround();
@Child private VectorIterator.Double cIterator = VectorIterator.Double.createWrapAround();
protected RandFunctionIntExecutorNode(RandFunction3_Double function) { protected RandFunctionIntExecutorNode(RandFunction3_Double function) {
this.function = function; this.function = function;
...@@ -180,13 +184,16 @@ public final class RandFunctionsNodes { ...@@ -180,13 +184,16 @@ public final class RandFunctionsNodes {
return RDataFactory.createIntVector(nansResult, false); return RDataFactory.createIntVector(nansResult, false);
} }
Object aIt = aIterator.init(a);
Object bIt = bIterator.init(b);
Object cIt = cIterator.init(c);
boolean nans = false; boolean nans = false;
int[] result = new int[length]; int[] result = new int[length];
nodeData.loopConditionProfile.profileCounted(length); nodeData.loopConditionProfile.profileCounted(length);
for (int i = 0; nodeData.loopConditionProfile.inject(i < length); i++) { for (int i = 0; nodeData.loopConditionProfile.inject(i < length); i++) {
double aValue = a.getDataAt(i % aLength); double aValue = aIterator.next(a, aIt);
double bValue = b.getDataAt(i % bLength); double bValue = bIterator.next(b, bIt);
double cValue = c.getDataAt(i % cLength); double cValue = cIterator.next(c, cIt);
double value = function.execute(aValue, bValue, cValue, randProvider); double value = function.execute(aValue, bValue, cValue, randProvider);
if (Double.isNaN(value) || value <= Integer.MIN_VALUE || value > Integer.MAX_VALUE) { if (Double.isNaN(value) || value <= Integer.MIN_VALUE || value > Integer.MAX_VALUE) {
nodeData.nan.enter(); nodeData.nan.enter();
...@@ -206,6 +213,9 @@ public final class RandFunctionsNodes { ...@@ -206,6 +213,9 @@ public final class RandFunctionsNodes {
protected abstract static class RandFunctionDoubleExecutorNode extends RandFunctionExecutorBase { protected abstract static class RandFunctionDoubleExecutorNode extends RandFunctionExecutorBase {
@Child private RandFunction3_Double function; @Child private RandFunction3_Double function;
@Child private VectorIterator.Double aIterator = VectorIterator.Double.createWrapAround();
@Child private VectorIterator.Double bIterator = VectorIterator.Double.createWrapAround();
@Child private VectorIterator.Double cIterator = VectorIterator.Double.createWrapAround();
protected RandFunctionDoubleExecutorNode(RandFunction3_Double function) { protected RandFunctionDoubleExecutorNode(RandFunction3_Double function) {
this.function = function; this.function = function;
...@@ -225,14 +235,17 @@ public final class RandFunctionsNodes { ...@@ -225,14 +235,17 @@ public final class RandFunctionsNodes {
return RDataFactory.createDoubleVector(nansResult, false); return RDataFactory.createDoubleVector(nansResult, false);
} }
Object aIt = aIterator.init(a);
Object bIt = bIterator.init(b);
Object cIt = cIterator.init(c);
boolean nans = false; boolean nans = false;
double[] result; double[] result;
result = new double[length]; result = new double[length];
nodeData.loopConditionProfile.profileCounted(length); nodeData.loopConditionProfile.profileCounted(length);
for (int i = 0; nodeData.loopConditionProfile.inject(i < length); i++) { for (int i = 0; nodeData.loopConditionProfile.inject(i < length); i++) {
double aValue = a.getDataAt(i % aLength); double aValue = aIterator.next(a, aIt);
double bValue = b.getDataAt(i % bLength); double bValue = bIterator.next(b, bIt);
double cValue = c.getDataAt(i % cLength); double cValue = cIterator.next(c, cIt);
double value = function.execute(aValue, bValue, cValue, randProvider); double value = function.execute(aValue, bValue, cValue, randProvider);
if (Double.isNaN(value) || RRuntime.isNA(value)) { if (Double.isNaN(value) || RRuntime.isNA(value)) {
nodeData.nan.enter(); nodeData.nan.enter();
......
...@@ -334,6 +334,10 @@ public abstract class VectorIterator<T> extends Node { ...@@ -334,6 +334,10 @@ public abstract class VectorIterator<T> extends Node {
public static Double create() { public static Double create() {
return new Double(false); return new Double(false);
} }
public static Double createWrapAround() {
return new Double(true);
}
} }
public static final class Logical extends VectorIterator<Byte> { public static final class Logical extends VectorIterator<Byte> {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment