diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RBeta.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RBeta.java new file mode 100644 index 0000000000000000000000000000000000000000..d7e72dd1f0745a5be4e65d8488bc54fdaee65656 --- /dev/null +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RBeta.java @@ -0,0 +1,157 @@ +/* + * This material is distributed under the GNU General Public License + * Version 2. You may review the terms of this license at + * http://www.gnu.org/licenses/gpl-2.0.html + * + * Copyright (c) 1995, 1996, 1997 Robert Gentleman and Ross Ihaka + * Copyright (c) 1998-2013, The R Core Team + * Copyright (c) 2003-2015, The R Foundation + * Copyright (c) 2016, 2016, Oracle and/or its affiliates + * + * All rights reserved. + */ +package com.oracle.truffle.r.library.stats; + +import static com.oracle.truffle.r.library.stats.MathConstants.M_LN2; +import static com.oracle.truffle.r.library.stats.StatsUtil.DBL_MAX_EXP; +import static com.oracle.truffle.r.library.stats.StatsUtil.fmax2; +import static com.oracle.truffle.r.library.stats.StatsUtil.fmin2; + +import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_Double; +import com.oracle.truffle.r.runtime.rng.RandomNumberNode; + +public final class RBeta implements RandFunction2_Double { + + private static final double expmax = (DBL_MAX_EXP * M_LN2); /* = log(DBL_MAX) */ + + @Override + public double evaluate(int index, double aa, double bb, double random, RandomNumberNode randomNode) { + if (Double.isNaN(aa) || Double.isNaN(bb) || aa < 0. || bb < 0.) { + StatsUtil.mlError(); + } + if (!Double.isFinite(aa) && !Double.isFinite(bb)) { // a = b = Inf : all mass at 1/2 + return 0.5; + } + if (aa == 0. && bb == 0.) { // point mass 1/2 at each of {0,1} : + return (randomNode.executeSingleDouble() < 0.5) ? 0. : 1.; + } + // now, at least one of a, b is finite and positive + if (!Double.isFinite(aa) || bb == 0.) { + return 1.0; + } + if (!Double.isFinite(bb) || aa == 0.) { + return 0.0; + } + + double a; + double b; + double r; + double s; + double t; + double u1; + double u2; + double v = 0; + double w = 0; + double y; + double z; + double olda = -1.0; + double oldb = -1.0; + + double beta = 0; + double gamma = 1; + double delta; + double k1 = 0; + double k2 = 0; + + /* Test if we need new "initializing" */ + boolean qsame = (olda == aa) && (oldb == bb); + if (!qsame) { + olda = aa; + oldb = bb; + } + + a = fmin2(aa, bb); + b = fmax2(aa, bb); /* a <= b */ + double alpha = a + b; + + if (a <= 1.0) { /* --- Algorithm BC --- */ + /* changed notation, now also a <= b (was reversed) */ + if (!qsame) { /* initialize */ + beta = 1.0 / a; + delta = 1.0 + b - a; + k1 = delta * (0.0138889 + 0.0416667 * a) / (b * beta - 0.777778); + k2 = 0.25 + (0.5 + 0.25 / delta) * a; + } + /* FIXME: "do { } while()", but not trivially because of "continue"s: */ + for (;;) { + u1 = randomNode.executeSingleDouble(); + u2 = randomNode.executeSingleDouble(); + if (u1 < 0.5) { + y = u1 * u2; + z = u1 * y; + if (0.25 * u2 + z - y >= k1) { + continue; + } + } else { + z = u1 * u1 * u2; + if (z <= 0.25) { + v = beta * Math.log(u1 / (1.0 - u1)); + w = wFromU1Bet(b, v, w); + break; + } + if (z >= k2) { + continue; + } + } + + v = beta * Math.log(u1 / (1.0 - u1)); + w = wFromU1Bet(b, v, w); + + if (alpha * (Math.log(alpha / (a + w)) + v) - 1.3862944 >= Math.log(z)) { + break; + } + } + return (aa == a) ? a / (a + w) : w / (a + w); + + } else { /* Algorithm BB */ + + if (!qsame) { /* initialize */ + beta = Math.sqrt((alpha - 2.0) / (2.0 * a * b - alpha)); + gamma = a + 1.0 / beta; + } + do { + u1 = randomNode.executeSingleDouble(); + u2 = randomNode.executeSingleDouble(); + + v = beta * Math.log(u1 / (1.0 - u1)); + w = wFromU1Bet(a, v, w); + + z = u1 * u1 * u2; + r = gamma * v - 1.3862944; + s = a + r - w; + if (s + 2.609438 >= 5.0 * z) { + break; + } + t = Math.log(z); + if (s > t) { + break; + } + } while (r + alpha * Math.log(alpha / (b + w)) < t); + + return (aa != a) ? b / (b + w) : w / (b + w); + } + } + + private static double wFromU1Bet(double aa, double v, double w) { + if (v <= expmax) { + w = aa * Math.exp(v); + if (!Double.isFinite(w)) { + w = Double.MAX_VALUE; + } + } else { + w = Double.MAX_VALUE; + } + return w; + } + +} diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsUtil.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsUtil.java index 55511aac7dee9757d8d5ee5bab6b7ff311f449ca..76fe2473335bca7fd404efc89a49945704360171 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsUtil.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsUtil.java @@ -119,6 +119,10 @@ public class StatsUtil { return giveLog ? -0.5 * Math.log(f) + x : Math.exp(x) / Math.sqrt(f); } + // + // GNUR from fmin2.c and fmax2 + // + public static double fmax2(double x, double y) { if (Double.isNaN(x) || Double.isNaN(y)) { return x + y; @@ -126,6 +130,13 @@ public class StatsUtil { return (x < y) ? y : x; } + public static double fmin2(double x, double y) { + if (Double.isNaN(x) || Double.isNaN(y)) { + return x + y; + } + return (x < y) ? x : y; + } + // // GNUR from expm1.c // diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/ForeignFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/ForeignFunctions.java index 9b49dab85d20085b3ddb4ed9dc1916e290d87afe..291965766cadbd5aaf0566d5ebcf958fa2659b2c 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/ForeignFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/ForeignFunctions.java @@ -51,6 +51,7 @@ import com.oracle.truffle.r.library.stats.Pf; import com.oracle.truffle.r.library.stats.Pnorm; import com.oracle.truffle.r.library.stats.Qbinom; import com.oracle.truffle.r.library.stats.Qnorm; +import com.oracle.truffle.r.library.stats.RBeta; import com.oracle.truffle.r.library.stats.RandGenerationFunctionsFactory; import com.oracle.truffle.r.library.stats.Rbinom; import com.oracle.truffle.r.library.stats.Rnorm; @@ -362,6 +363,8 @@ public class ForeignFunctions { return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new Rnorm()); case "runif": return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new Runif()); + case "rbeta": + return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new RBeta()); case "qgamma": return StatsFunctionsFactory.Function3_2NodeGen.create(new QgammaFunc()); case "dbinom": diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index 998aa5bca342d0d75f878d80a9bc810dc88baa9c..873db69b5d38c28f7a9ee70428f52258ef62193a 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -111006,6 +111006,16 @@ Error: 'x' is NULL Warning message: In qgamma(10, 1) : NaNs produced +##com.oracle.truffle.r.test.library.stats.TestExternal_rbeta.testRbeta# +#set.seed(42); rbeta(10, 10, 10) + [1] 0.4282247 0.5459560 0.5805863 0.5512005 0.4866080 0.6987626 0.4880555 + [8] 0.7691043 0.4920874 0.6702352 + +##com.oracle.truffle.r.test.library.stats.TestExternal_rbeta.testRbeta# +#set.seed(42); rbeta(10, c(0.1, 2:10), c(0.1, 0.5, 0.9, 3:5)) + [1] 0.002930982 0.969019187 0.872817723 0.593769928 0.260911852 0.561458988 + [7] 1.000000000 0.929063923 0.991793861 0.914489454 + ##com.oracle.truffle.r.test.library.stats.TestExternal_rbinom.testRbinom# #set.seed(42); rbinom('10', 10, 0.5) [1] 7 7 4 7 6 5 6 3 6 6 diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestExternal_rbeta.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestExternal_rbeta.java new file mode 100644 index 0000000000000000000000000000000000000000..f4ab519fd490edcbcb89bb8f9e1f1e6a94b2cf7b --- /dev/null +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestExternal_rbeta.java @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.test.library.stats; + +import org.junit.Test; + +import com.oracle.truffle.r.test.TestBase; + +public class TestExternal_rbeta extends TestBase { + @Test + public void testRbeta() { + assertEval("set.seed(42); rbeta(10, 10, 10)"); + assertEval("set.seed(42); rbeta(10, c(0.1, 2:10), c(0.1, 0.5, 0.9, 3:5))"); + } +} diff --git a/mx.fastr/copyrights/overrides b/mx.fastr/copyrights/overrides index b2774bbe81df2fcdb8776e053ecbdb3aec5e15f4..e09d6f15e12a8430210f80dc35b8ccd00265194f 100644 --- a/mx.fastr/copyrights/overrides +++ b/mx.fastr/copyrights/overrides @@ -29,6 +29,7 @@ com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/grid/GridFunctions com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/methods/MethodsListDispatch.java,gnu_r.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/methods/Slot.java,gnu_r.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Arithmetic.java,gnu_r_gentleman_ihaka.copyright +com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RBeta.java,gnu_r_gentleman_ihaka.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/CompleteCases.java,gnu_r_gentleman_ihaka2.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Covcor.java,gnu_r.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Dbinom.java,gnu_r.copyright