From d9033b07a6f5574f5d44ac69c0b91c40a99c6fa0 Mon Sep 17 00:00:00 2001 From: stepan <stepan.sindelar@oracle.com> Date: Fri, 25 Nov 2016 23:06:19 +0100 Subject: [PATCH] Implement C_rnchisq adding support for rchisq with 3 arguments --- .../truffle/r/library/stats/RNchisq.java | 41 +++ .../oracle/truffle/r/library/stats/RPois.java | 276 ++++++++++++++++++ .../truffle/r/library/stats/StatsUtil.java | 7 + .../base/foreign/ForeignFunctions.java | 3 + .../stats/TestRandGenerationFunctions.java | 2 +- mx.fastr/copyrights/overrides | 2 + 6 files changed, 330 insertions(+), 1 deletion(-) create mode 100644 com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RNchisq.java create mode 100644 com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RPois.java diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RNchisq.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RNchisq.java new file mode 100644 index 0000000000..da55025e1a --- /dev/null +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RNchisq.java @@ -0,0 +1,41 @@ +/* + * This material is distributed under the GNU General Public License + * Version 2. You may review the terms of this license at + * http://www.gnu.org/licenses/gpl-2.0.html + * + * Copyright (c) 1995-2015, The R Core Team + * Copyright (c) 2015, The R Foundation + * Copyright (c) 2016, Oracle and/or its affiliates + * + * All rights reserved. + */ + +// TODO: fix copyright +package com.oracle.truffle.r.library.stats; + +import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_Double; +import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider; + +public final class RNchisq implements RandFunction2_Double { + private final RGamma rgamma = new RGamma(); + + @Override + public double evaluate(double df, double lambda, RandomNumberProvider rand) { + if (!Double.isFinite(df) || !Double.isFinite(lambda) || df < 0. || lambda < 0.) { + return StatsUtil.mlError(); + } + + if (lambda == 0.) { + return (df == 0.) ? 0. : rgamma.evaluate(df / 2., 2., rand); + } else { + double r = RPois.rpois(lambda / 2., rand); + if (r > 0.) { + r = RChisq.rchisq(2. * r, rand); + } + if (df > 0.) { + r += rgamma.evaluate(df / 2., 2., rand); + } + return r; + } + } +} diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RPois.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RPois.java new file mode 100644 index 0000000000..46566a17ff --- /dev/null +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RPois.java @@ -0,0 +1,276 @@ +/* + * This material is distributed under the GNU General Public License + * Version 2. You may review the terms of this license at + * http://www.gnu.org/licenses/gpl-2.0.html + * + * Copyright (C) 1998 Ross Ihaka + * Copyright (c) 1998--2008, The R Core Team + * Copyright (c) 2016, 2016, Oracle and/or its affiliates + * + * All rights reserved. + */ +package com.oracle.truffle.r.library.stats; + +import static com.oracle.truffle.r.library.stats.MathConstants.M_1_SQRT_2PI; + +import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider; + +public final class RPois { + + private static final double a0 = -0.5; + private static final double a1 = 0.3333333; + private static final double a2 = -0.2500068; + private static final double a3 = 0.2000118; + private static final double a4 = -0.1661269; + private static final double a5 = 0.1421878; + private static final double a6 = -0.1384794; + private static final double a7 = 0.1250060; + + private static final double one_7 = 0.1428571428571428571; + private static final double one_12 = 0.0833333333333333333; + private static final double one_24 = 0.0416666666666666667; + + /* Factorial Table (0:9)! */ + private static final double[] fact = new double[]{ + 1., 1., 2., 6., 24., 120., 720., 5040., 40320., 362880. + }; + + public static double rpois(double mu, RandomNumberProvider rand) { + + /* These are static --- persistent between calls for same mu : */ + // TODO: state variables + int l = 0; + int m = 0; + double b1; + double b2; + double c = 0; + double c0 = 0; + double c1 = 0; + double c2 = 0; + double c3 = 0; + double[] pp = new double[36]; + double p0 = 0; + double p = 0; + double q = 0; + double s = 0; + double d = 0; + double omega = 0; + double bigL = 0; /* integer "w/o overflow" */ + double muprev = 0.; + double muprev2 = 0.; /* , muold = 0. */ + + /* Local Vars [initialize some for -Wall]: */ + double del; + double difmuk = 0.; + double e = 0.; + double fk = 0.; + double fx; + double fy; + double g; + double px; + double py; + double t = 0; + double u = 0.; + double v; + double x; + double pois = -1.; + int k; + int kflag = 0; + boolean bigMu; + boolean newBigMu = false; + + if (!Double.isFinite(mu) || mu < 0) { + return StatsUtil.mlError(); + } + + if (mu <= 0.) { + return 0.; + } + + bigMu = mu >= 10.; + if (bigMu) { + newBigMu = false; + } + + if (!(bigMu && mu == muprev)) { /* maybe compute new persistent par.s */ + + if (bigMu) { + newBigMu = true; + /* + * Case A. (recalculation of s,d,l because mu has changed): The poisson + * probabilities pk exceed the discrete normal probabilities fk whenever k >= m(mu). + */ + muprev = mu; + s = Math.sqrt(mu); + d = 6. * mu * mu; + bigL = Math.floor(mu - 1.1484); + /* = an upper bound to m(mu) for all mu >= 10. */ + } else { /* Small mu ( < 10) -- not using normal approx. */ + + /* Case B. (start new table and calculate p0 if necessary) */ + + /* muprev = 0.;-* such that next time, mu != muprev .. */ + if (mu != muprev) { + muprev = mu; + m = Math.max(1, (int) mu); + l = 0; /* pp[] is already ok up to pp[l] */ + q = p0 = p = Math.exp(-mu); + } + + while (true) { + /* Step U. uniform sample for inversion method */ + u = rand.unifRand(); + if (u <= p0) { + return 0.; + } + + /* + * Step T. table comparison until the end pp[l] of the pp-table of cumulative + * poisson probabilities (0.458 > ~= pp[9](= 0.45792971447) for mu=10 ) + */ + if (l != 0) { + for (k = (u <= 0.458) ? 1 : Math.min(l, m); k <= l; k++) { + if (u <= pp[k]) { + return (double) k; + } + } + if (l == 35) { /* u > pp[35] */ + continue; + } + } + /* + * Step C. creation of new poisson probabilities p[l..] and their cumulatives q + * =: pp[k] + */ + l++; + for (k = l; k <= 35; k++) { + p *= mu / k; + q += p; + pp[k] = q; + if (u <= q) { + l = k; + return (double) k; + } + } + l = 35; + } /* end(repeat) */ + } /* mu < 10 */ + + } /* end {initialize persistent vars} */ + + /* Only if mu >= 10 : ----------------------- */ + + /* Step N. normal sample */ + g = mu + s * rand.unifRand(); /* norm_rand() ~ N(0,1), standard normal */ + + if (g >= 0.) { + pois = Math.floor(g); + /* Step I. immediate acceptance if pois is large enough */ + if (pois >= bigL) { + return pois; + } + /* Step S. squeeze acceptance */ + fk = pois; + difmuk = mu - fk; + u = rand.unifRand(); /* ~ U(0,1) - sample */ + if (d * u >= difmuk * difmuk * difmuk) { + return pois; + } + } + + /* + * Step P. preparations for steps Q and H. (recalculations of parameters if necessary) + */ + + if (newBigMu || mu != muprev2) { + /* + * Careful! muprev2 is not always == muprev because one might have exited in step I or S + */ + muprev2 = mu; + omega = M_1_SQRT_2PI / s; + /* + * The quantities b1, b2, c3, c2, c1, c0 are for the Hermite approximations to the + * discrete normal probabilities fk. + */ + + b1 = one_24 / mu; + b2 = 0.3 * b1 * b1; + c3 = one_7 * b1 * b2; + c2 = b2 - 15. * c3; + c1 = b1 - 6. * b2 + 45. * c3; + c0 = 1. - b1 + 3. * b2 - 15. * c3; + c = 0.1069 / mu; /* guarantees majorization by the 'hat'-function. */ + } + + boolean gotoStepF = false; + if (g >= 0.) { + /* 'Subroutine' F is called (kflag=0 for correct return) */ + kflag = 0; + gotoStepF = true; + // goto Step_F; + } + + while (true) { + if (!gotoStepF) { + /* Step E. Exponential Sample */ + e = rand.expRand(); /* ~ Exp(1) (standard exponential) */ + + /* + * sample t from the laplace 'hat' (if t <= -0.6744 then pk < fk for all mu >= 10.) + */ + u = 2 * rand.unifRand() - 1.; + t = 1.8 + StatsUtil.fsign(e, u); + } + + if (t > -0.6744 || gotoStepF) { + if (!gotoStepF) { + pois = Math.floor(mu + s * t); + fk = pois; + difmuk = mu - fk; + + /* 'subroutine' F is called (kflag=1 for correct return) */ + kflag = 1; + } + + // Step_F: /* 'subroutine' F : calculation of px,py,fx,fy. */ + gotoStepF = false; + + if (pois < 10) { /* use factorials from table fact[] */ + px = -mu; + py = Math.pow(mu, pois) / fact[(int) pois]; + } else { + /* + * Case pois >= 10 uses polynomial approximation a0-a7 for accuracy when + * advisable + */ + del = one_12 / fk; + del = del * (1. - 4.8 * del * del); + v = difmuk / fk; + if (TOMS708.fabs(v) <= 0.25) { + px = fk * v * v * (((((((a7 * v + a6) * v + a5) * v + a4) * + v + a3) * v + a2) * v + a1) * v + a0) - del; + } else { /* |v| > 1/4 */ + px = fk * Math.log(1. + v) - difmuk - del; + } + py = M_1_SQRT_2PI / Math.sqrt(fk); + } + x = (0.5 - difmuk) / s; + x *= x; /* x^2 */ + fx = -0.5 * x; + fy = omega * (((c3 * x + c2) * x + c1) * x + c0); + if (kflag > 0) { + /* Step H. Hat acceptance (E is repeated on rejection) */ + if (c * TOMS708.fabs(u) <= py * Math.exp(px + e) - fy * Math.exp(fx + e)) { + break; + } + } else { + /* Step Q. Quotient acceptance (rare case) */ + if (fy - u * fy <= py * Math.exp(px - fx)) { + break; + } + } + } /* t > -.67.. */ + } + return pois; + } +} diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsUtil.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsUtil.java index 76fe247333..7d509e45bc 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsUtil.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsUtil.java @@ -119,6 +119,13 @@ public class StatsUtil { return giveLog ? -0.5 * Math.log(f) + x : Math.exp(x) / Math.sqrt(f); } + public static double fsign(double x, double y) { + if (Double.isNaN(x) || Double.isNaN(y)) { + return x + y; + } + return ((y >= 0) ? TOMS708.fabs(x) : -TOMS708.fabs(x)); + } + // // GNUR from fmin2.c and fmax2 // diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/ForeignFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/ForeignFunctions.java index 9af31ee9a7..cd17d41445 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/ForeignFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/ForeignFunctions.java @@ -58,6 +58,7 @@ import com.oracle.truffle.r.library.stats.RBeta; import com.oracle.truffle.r.library.stats.RCauchy; import com.oracle.truffle.r.library.stats.RGamma; import com.oracle.truffle.r.library.stats.RLogis; +import com.oracle.truffle.r.library.stats.RNchisq; import com.oracle.truffle.r.library.stats.RWeibull; import com.oracle.truffle.r.library.stats.RandGenerationFunctionsFactory; import com.oracle.truffle.r.library.stats.Rbinom; @@ -384,6 +385,8 @@ public class ForeignFunctions { return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new RLogis()); case "rweibull": return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new RWeibull()); + case "rnchisq": + return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new RNchisq()); case "qgamma": return StatsFunctionsFactory.Function3_2NodeGen.create(new QgammaFunc()); case "dbinom": diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestRandGenerationFunctions.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestRandGenerationFunctions.java index 4dccd15e28..37f836aed5 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestRandGenerationFunctions.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestRandGenerationFunctions.java @@ -31,7 +31,7 @@ import com.oracle.truffle.r.test.TestBase; * tests for its specific corner cases if those are not covered here. */ public class TestRandGenerationFunctions extends TestBase { - private static final String[] FUNCTION2_NAMES = {"rnorm", "runif", "rgamma", "rbeta", "rcauchy", "rf", "rlogis", "rweibull"}; + private static final String[] FUNCTION2_NAMES = {"rnorm", "runif", "rgamma", "rbeta", "rcauchy", "rf", "rlogis", "rweibull", "rchisq"}; private static final String[] FUNCTION2_PARAMS = { "10, 10, 10", "20, c(-1, 0, 0.2, 2:5), c(-1, 0, 0.1, 0.9, 3)", diff --git a/mx.fastr/copyrights/overrides b/mx.fastr/copyrights/overrides index 8a4dcc46ef..3d0c71c008 100644 --- a/mx.fastr/copyrights/overrides +++ b/mx.fastr/copyrights/overrides @@ -48,6 +48,7 @@ com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Pnorm.java,g com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Qbinom.java,gnu_r_ihaka.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Qnorm.java,gnu_r_ihaka.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Random2.java,gnu_r.copyright +com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RPois.java,gnu_r_ihaka_core.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rbinom.java,gnu_r_ihaka.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rnorm.java,gnu_r.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/SplineFunctions.java,gnu_r_splines.copyright @@ -61,6 +62,7 @@ com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RGamma.java, com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RLogis.java,gnu_r_ihaka_core.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rf.java,gnu_r_ihaka_core.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RChisq.java,gnu_r_ihaka_core.copyright +com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RNchisq.java,gnu_r.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RWeibull.java,gnu_r_ihaka_core.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/tools/DirChmod.java,gnu_r.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/tools/ToolsText.java,gnu_r.copyright -- GitLab