From f7a452db82520bd60e948f2c033d6855e7eeb5ba Mon Sep 17 00:00:00 2001
From: stepan <stepan.sindelar@oracle.com>
Date: Wed, 23 Nov 2016 11:31:56 +0100
Subject: [PATCH] implement rbeta

---
 .../oracle/truffle/r/library/stats/RBeta.java | 157 ++++++++++++++++++
 .../truffle/r/library/stats/StatsUtil.java    |  11 ++
 .../base/foreign/ForeignFunctions.java        |   3 +
 .../truffle/r/test/ExpectedTestOutput.test    |  10 ++
 .../library/stats/TestExternal_rbeta.java     |  35 ++++
 mx.fastr/copyrights/overrides                 |   1 +
 6 files changed, 217 insertions(+)
 create mode 100644 com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RBeta.java
 create mode 100644 com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestExternal_rbeta.java

diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RBeta.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RBeta.java
new file mode 100644
index 0000000000..d7e72dd1f0
--- /dev/null
+++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RBeta.java
@@ -0,0 +1,157 @@
+/*
+ * This material is distributed under the GNU General Public License
+ * Version 2. You may review the terms of this license at
+ * http://www.gnu.org/licenses/gpl-2.0.html
+ *
+ * Copyright (c) 1995, 1996, 1997  Robert Gentleman and Ross Ihaka
+ * Copyright (c) 1998-2013, The R Core Team
+ * Copyright (c) 2003-2015, The R Foundation
+ * Copyright (c) 2016, 2016, Oracle and/or its affiliates
+ *
+ * All rights reserved.
+ */
+package com.oracle.truffle.r.library.stats;
+
+import static com.oracle.truffle.r.library.stats.MathConstants.M_LN2;
+import static com.oracle.truffle.r.library.stats.StatsUtil.DBL_MAX_EXP;
+import static com.oracle.truffle.r.library.stats.StatsUtil.fmax2;
+import static com.oracle.truffle.r.library.stats.StatsUtil.fmin2;
+
+import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_Double;
+import com.oracle.truffle.r.runtime.rng.RandomNumberNode;
+
+public final class RBeta implements RandFunction2_Double {
+
+    private static final double expmax = (DBL_MAX_EXP * M_LN2); /* = log(DBL_MAX) */
+
+    @Override
+    public double evaluate(int index, double aa, double bb, double random, RandomNumberNode randomNode) {
+        if (Double.isNaN(aa) || Double.isNaN(bb) || aa < 0. || bb < 0.) {
+            StatsUtil.mlError();
+        }
+        if (!Double.isFinite(aa) && !Double.isFinite(bb)) { // a = b = Inf : all mass at 1/2
+            return 0.5;
+        }
+        if (aa == 0. && bb == 0.) { // point mass 1/2 at each of {0,1} :
+            return (randomNode.executeSingleDouble() < 0.5) ? 0. : 1.;
+        }
+        // now, at least one of a, b is finite and positive
+        if (!Double.isFinite(aa) || bb == 0.) {
+            return 1.0;
+        }
+        if (!Double.isFinite(bb) || aa == 0.) {
+            return 0.0;
+        }
+
+        double a;
+        double b;
+        double r;
+        double s;
+        double t;
+        double u1;
+        double u2;
+        double v = 0;
+        double w = 0;
+        double y;
+        double z;
+        double olda = -1.0;
+        double oldb = -1.0;
+
+        double beta = 0;
+        double gamma = 1;
+        double delta;
+        double k1 = 0;
+        double k2 = 0;
+
+        /* Test if we need new "initializing" */
+        boolean qsame = (olda == aa) && (oldb == bb);
+        if (!qsame) {
+            olda = aa;
+            oldb = bb;
+        }
+
+        a = fmin2(aa, bb);
+        b = fmax2(aa, bb); /* a <= b */
+        double alpha = a + b;
+
+        if (a <= 1.0) { /* --- Algorithm BC --- */
+            /* changed notation, now also a <= b (was reversed) */
+            if (!qsame) { /* initialize */
+                beta = 1.0 / a;
+                delta = 1.0 + b - a;
+                k1 = delta * (0.0138889 + 0.0416667 * a) / (b * beta - 0.777778);
+                k2 = 0.25 + (0.5 + 0.25 / delta) * a;
+            }
+            /* FIXME: "do { } while()", but not trivially because of "continue"s: */
+            for (;;) {
+                u1 = randomNode.executeSingleDouble();
+                u2 = randomNode.executeSingleDouble();
+                if (u1 < 0.5) {
+                    y = u1 * u2;
+                    z = u1 * y;
+                    if (0.25 * u2 + z - y >= k1) {
+                        continue;
+                    }
+                } else {
+                    z = u1 * u1 * u2;
+                    if (z <= 0.25) {
+                        v = beta * Math.log(u1 / (1.0 - u1));
+                        w = wFromU1Bet(b, v, w);
+                        break;
+                    }
+                    if (z >= k2) {
+                        continue;
+                    }
+                }
+
+                v = beta * Math.log(u1 / (1.0 - u1));
+                w = wFromU1Bet(b, v, w);
+
+                if (alpha * (Math.log(alpha / (a + w)) + v) - 1.3862944 >= Math.log(z)) {
+                    break;
+                }
+            }
+            return (aa == a) ? a / (a + w) : w / (a + w);
+
+        } else { /* Algorithm BB */
+
+            if (!qsame) { /* initialize */
+                beta = Math.sqrt((alpha - 2.0) / (2.0 * a * b - alpha));
+                gamma = a + 1.0 / beta;
+            }
+            do {
+                u1 = randomNode.executeSingleDouble();
+                u2 = randomNode.executeSingleDouble();
+
+                v = beta * Math.log(u1 / (1.0 - u1));
+                w = wFromU1Bet(a, v, w);
+
+                z = u1 * u1 * u2;
+                r = gamma * v - 1.3862944;
+                s = a + r - w;
+                if (s + 2.609438 >= 5.0 * z) {
+                    break;
+                }
+                t = Math.log(z);
+                if (s > t) {
+                    break;
+                }
+            } while (r + alpha * Math.log(alpha / (b + w)) < t);
+
+            return (aa != a) ? b / (b + w) : w / (b + w);
+        }
+    }
+
+    private static double wFromU1Bet(double aa, double v, double w) {
+        if (v <= expmax) {
+            w = aa * Math.exp(v);
+            if (!Double.isFinite(w)) {
+                w = Double.MAX_VALUE;
+            }
+        } else {
+            w = Double.MAX_VALUE;
+        }
+        return w;
+    }
+
+}
diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsUtil.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsUtil.java
index 55511aac7d..76fe247333 100644
--- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsUtil.java
+++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsUtil.java
@@ -119,6 +119,10 @@ public class StatsUtil {
         return giveLog ? -0.5 * Math.log(f) + x : Math.exp(x) / Math.sqrt(f);
     }
 
+    //
+    // GNUR from fmin2.c and fmax2
+    //
+
     public static double fmax2(double x, double y) {
         if (Double.isNaN(x) || Double.isNaN(y)) {
             return x + y;
@@ -126,6 +130,13 @@ public class StatsUtil {
         return (x < y) ? y : x;
     }
 
+    public static double fmin2(double x, double y) {
+        if (Double.isNaN(x) || Double.isNaN(y)) {
+            return x + y;
+        }
+        return (x < y) ? x : y;
+    }
+
     //
     // GNUR from expm1.c
     //
diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/ForeignFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/ForeignFunctions.java
index 9b49dab85d..291965766c 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/ForeignFunctions.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/ForeignFunctions.java
@@ -51,6 +51,7 @@ import com.oracle.truffle.r.library.stats.Pf;
 import com.oracle.truffle.r.library.stats.Pnorm;
 import com.oracle.truffle.r.library.stats.Qbinom;
 import com.oracle.truffle.r.library.stats.Qnorm;
+import com.oracle.truffle.r.library.stats.RBeta;
 import com.oracle.truffle.r.library.stats.RandGenerationFunctionsFactory;
 import com.oracle.truffle.r.library.stats.Rbinom;
 import com.oracle.truffle.r.library.stats.Rnorm;
@@ -362,6 +363,8 @@ public class ForeignFunctions {
                     return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new Rnorm());
                 case "runif":
                     return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new Runif());
+                case "rbeta":
+                    return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new RBeta());
                 case "qgamma":
                     return StatsFunctionsFactory.Function3_2NodeGen.create(new QgammaFunc());
                 case "dbinom":
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
index 998aa5bca3..873db69b5d 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
@@ -111006,6 +111006,16 @@ Error: 'x' is NULL
 Warning message:
 In qgamma(10, 1) : NaNs produced
 
+##com.oracle.truffle.r.test.library.stats.TestExternal_rbeta.testRbeta#
+#set.seed(42); rbeta(10, 10, 10)
+ [1] 0.4282247 0.5459560 0.5805863 0.5512005 0.4866080 0.6987626 0.4880555
+ [8] 0.7691043 0.4920874 0.6702352
+
+##com.oracle.truffle.r.test.library.stats.TestExternal_rbeta.testRbeta#
+#set.seed(42); rbeta(10, c(0.1, 2:10), c(0.1, 0.5, 0.9, 3:5))
+ [1] 0.002930982 0.969019187 0.872817723 0.593769928 0.260911852 0.561458988
+ [7] 1.000000000 0.929063923 0.991793861 0.914489454
+
 ##com.oracle.truffle.r.test.library.stats.TestExternal_rbinom.testRbinom#
 #set.seed(42); rbinom('10', 10, 0.5)
  [1] 7 7 4 7 6 5 6 3 6 6
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestExternal_rbeta.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestExternal_rbeta.java
new file mode 100644
index 0000000000..f4ab519fd4
--- /dev/null
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestExternal_rbeta.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+package com.oracle.truffle.r.test.library.stats;
+
+import org.junit.Test;
+
+import com.oracle.truffle.r.test.TestBase;
+
+public class TestExternal_rbeta extends TestBase {
+    @Test
+    public void testRbeta() {
+        assertEval("set.seed(42); rbeta(10, 10, 10)");
+        assertEval("set.seed(42); rbeta(10, c(0.1, 2:10), c(0.1, 0.5, 0.9, 3:5))");
+    }
+}
diff --git a/mx.fastr/copyrights/overrides b/mx.fastr/copyrights/overrides
index b2774bbe81..e09d6f15e1 100644
--- a/mx.fastr/copyrights/overrides
+++ b/mx.fastr/copyrights/overrides
@@ -29,6 +29,7 @@ com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/grid/GridFunctions
 com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/methods/MethodsListDispatch.java,gnu_r.copyright
 com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/methods/Slot.java,gnu_r.copyright
 com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Arithmetic.java,gnu_r_gentleman_ihaka.copyright
+com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RBeta.java,gnu_r_gentleman_ihaka.copyright
 com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/CompleteCases.java,gnu_r_gentleman_ihaka2.copyright
 com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Covcor.java,gnu_r.copyright
 com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Dbinom.java,gnu_r.copyright
-- 
GitLab