From 65c3be4e6d561ea74062b15e2a3a213b7bad2a14 Mon Sep 17 00:00:00 2001
From: stepan <stepan.sindelar@oracle.com>
Date: Fri, 22 Sep 2017 15:57:57 +0200
Subject: [PATCH] User VectorIterator in RandFunctionsNodes

---
 .../r/library/stats/RandFunctionsNodes.java   | 31 +++++++++++++------
 .../r/runtime/data/nodes/VectorIterator.java  |  4 +++
 2 files changed, 26 insertions(+), 9 deletions(-)

diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandFunctionsNodes.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandFunctionsNodes.java
index 93cd847a3d..f8b8ad866b 100644
--- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandFunctionsNodes.java
+++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandFunctionsNodes.java
@@ -13,11 +13,11 @@
 
 package com.oracle.truffle.r.library.stats;
 
-import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
-import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.missingValue;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.abstractVectorValue;
-import static com.oracle.truffle.r.runtime.RError.SHOW_CALLER;
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.missingValue;
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
 import static com.oracle.truffle.r.runtime.RError.Message.INVALID_UNNAMED_ARGUMENTS;
+import static com.oracle.truffle.r.runtime.RError.SHOW_CALLER;
 
 import java.util.Arrays;
 
@@ -43,6 +43,7 @@ import com.oracle.truffle.r.runtime.data.RDouble;
 import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
+import com.oracle.truffle.r.runtime.data.nodes.VectorIterator;
 import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction1_Double;
 import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction2_Double;
 import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction3_Double;
@@ -161,6 +162,9 @@ public final class RandFunctionsNodes {
 
     protected abstract static class RandFunctionIntExecutorNode extends RandFunctionExecutorBase {
         @Child private RandFunction3_Double function;
+        @Child private VectorIterator.Double aIterator = VectorIterator.Double.createWrapAround();
+        @Child private VectorIterator.Double bIterator = VectorIterator.Double.createWrapAround();
+        @Child private VectorIterator.Double cIterator = VectorIterator.Double.createWrapAround();
 
         protected RandFunctionIntExecutorNode(RandFunction3_Double function) {
             this.function = function;
@@ -180,13 +184,16 @@ public final class RandFunctionsNodes {
                 return RDataFactory.createIntVector(nansResult, false);
             }
 
+            Object aIt = aIterator.init(a);
+            Object bIt = bIterator.init(b);
+            Object cIt = cIterator.init(c);
             boolean nans = false;
             int[] result = new int[length];
             nodeData.loopConditionProfile.profileCounted(length);
             for (int i = 0; nodeData.loopConditionProfile.inject(i < length); i++) {
-                double aValue = a.getDataAt(i % aLength);
-                double bValue = b.getDataAt(i % bLength);
-                double cValue = c.getDataAt(i % cLength);
+                double aValue = aIterator.next(a, aIt);
+                double bValue = bIterator.next(b, bIt);
+                double cValue = cIterator.next(c, cIt);
                 double value = function.execute(aValue, bValue, cValue, randProvider);
                 if (Double.isNaN(value) || value <= Integer.MIN_VALUE || value > Integer.MAX_VALUE) {
                     nodeData.nan.enter();
@@ -206,6 +213,9 @@ public final class RandFunctionsNodes {
 
     protected abstract static class RandFunctionDoubleExecutorNode extends RandFunctionExecutorBase {
         @Child private RandFunction3_Double function;
+        @Child private VectorIterator.Double aIterator = VectorIterator.Double.createWrapAround();
+        @Child private VectorIterator.Double bIterator = VectorIterator.Double.createWrapAround();
+        @Child private VectorIterator.Double cIterator = VectorIterator.Double.createWrapAround();
 
         protected RandFunctionDoubleExecutorNode(RandFunction3_Double function) {
             this.function = function;
@@ -225,14 +235,17 @@ public final class RandFunctionsNodes {
                 return RDataFactory.createDoubleVector(nansResult, false);
             }
 
+            Object aIt = aIterator.init(a);
+            Object bIt = bIterator.init(b);
+            Object cIt = cIterator.init(c);
             boolean nans = false;
             double[] result;
             result = new double[length];
             nodeData.loopConditionProfile.profileCounted(length);
             for (int i = 0; nodeData.loopConditionProfile.inject(i < length); i++) {
-                double aValue = a.getDataAt(i % aLength);
-                double bValue = b.getDataAt(i % bLength);
-                double cValue = c.getDataAt(i % cLength);
+                double aValue = aIterator.next(a, aIt);
+                double bValue = bIterator.next(b, bIt);
+                double cValue = cIterator.next(c, cIt);
                 double value = function.execute(aValue, bValue, cValue, randProvider);
                 if (Double.isNaN(value) || RRuntime.isNA(value)) {
                     nodeData.nan.enter();
diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/VectorIterator.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/VectorIterator.java
index 350f80170e..f2ffc6e80b 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/VectorIterator.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/data/nodes/VectorIterator.java
@@ -334,6 +334,10 @@ public abstract class VectorIterator<T> extends Node {
         public static Double create() {
             return new Double(false);
         }
+
+        public static Double createWrapAround() {
+            return new Double(true);
+        }
     }
 
     public static final class Logical extends VectorIterator<Byte> {
-- 
GitLab