From b2018cabc65f91c95735ee7e3efaf76d4e372ab4 Mon Sep 17 00:00:00 2001 From: stepan <stepan.sindelar@oracle.com> Date: Mon, 19 Dec 2016 20:10:11 +0100 Subject: [PATCH] RandGenerationFunctions fixed to be more partial evaluation friendly - individual random generation functions converted to nodes (with profiles) - DSL powered cache of the concrete type of current random number generator --- .../truffle/r/library/stats/Cauchy.java | 4 +- .../oracle/truffle/r/library/stats/DPQ.java | 16 +- .../oracle/truffle/r/library/stats/Exp.java | 10 +- .../oracle/truffle/r/library/stats/Geom.java | 10 +- .../truffle/r/library/stats/LogNormal.java | 6 +- .../oracle/truffle/r/library/stats/RBeta.java | 4 +- .../truffle/r/library/stats/RChisq.java | 6 +- .../truffle/r/library/stats/RGamma.java | 4 +- .../truffle/r/library/stats/RHyper.java | 6 +- .../truffle/r/library/stats/RLogis.java | 4 +- .../truffle/r/library/stats/RMultinom.java | 2 +- .../truffle/r/library/stats/RNbinomMu.java | 6 +- .../truffle/r/library/stats/RNchisq.java | 8 +- .../oracle/truffle/r/library/stats/RPois.java | 42 +- .../truffle/r/library/stats/RWeibull.java | 4 +- .../stats/RandGenerationFunctions.java | 381 ++++++++++-------- .../truffle/r/library/stats/Rbinom.java | 4 +- .../oracle/truffle/r/library/stats/Rf.java | 4 +- .../oracle/truffle/r/library/stats/Rnorm.java | 17 +- .../oracle/truffle/r/library/stats/Rt.java | 4 +- .../oracle/truffle/r/library/stats/Runif.java | 17 +- .../truffle/r/library/stats/Signrank.java | 7 +- .../truffle/r/library/stats/Wilcox.java | 7 +- .../foreign/CallAndExternalFunctions.java | 44 +- 24 files changed, 356 insertions(+), 261 deletions(-) diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cauchy.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cauchy.java index 3ab4aaf454..f316640b6f 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cauchy.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cauchy.java @@ -25,9 +25,9 @@ public final class Cauchy { // contains only static classes } - public static final class RCauchy implements RandFunction2_Double { + public static final class RCauchy extends RandFunction2_Double { @Override - public double evaluate(double location, double scale, RandomNumberProvider rand) { + public double execute(double location, double scale, RandomNumberProvider rand) { if (Double.isNaN(location) || !Double.isFinite(scale) || scale < 0) { return RMath.mlError(); } diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/DPQ.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/DPQ.java index 4925be2e76..24e3dfb71a 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/DPQ.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/DPQ.java @@ -12,7 +12,6 @@ package com.oracle.truffle.r.library.stats; import static com.oracle.truffle.r.library.stats.MathConstants.M_LN2; -import com.oracle.truffle.api.nodes.ControlFlowException; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RError.Message; @@ -26,13 +25,19 @@ public final class DPQ { // only static methods } - public static final class EarlyReturn extends ControlFlowException { + public static final class EarlyReturn extends Exception { private static final long serialVersionUID = 1182697355931636213L; public final double result; private EarlyReturn(double result) { this.result = result; } + + @SuppressWarnings("sync-override") + @Override + public Throwable fillInStackTrace() { + return null; + } } // R >= 3.1.0: # define R_nonint(x) (fabs((x) - R_forceint(x)) > 1e-7) @@ -148,6 +153,13 @@ public final class DPQ { } } + // R_P_bounds_Inf_01 + public static void rpboundsinf01(double x, boolean lowerTail, boolean logP) throws EarlyReturn { + if (!Double.isFinite(x)) { + throw new EarlyReturn(x > 0 ? rdt0(lowerTail, logP) : rdt0(lowerTail, logP)); + } + } + // R_Q_P01_check public static void rqp01check(double p, boolean logP) throws EarlyReturn { if ((logP && p > 0) || (!logP && (p < 0 || p > 1))) { diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Exp.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Exp.java index 05c8188260..0d3e6ecbc3 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Exp.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Exp.java @@ -17,7 +17,11 @@ import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberPr import com.oracle.truffle.r.library.stats.StatsFunctions.Function2_1; import com.oracle.truffle.r.library.stats.StatsFunctions.Function2_2; -public class Exp { +public final class Exp { + private Exp() { + // only static members + } + public static final class DExp implements Function2_1 { @Override public double evaluate(double x, double scale, boolean giveLog) { @@ -37,9 +41,9 @@ public class Exp { } } - public static final class RExp implements RandFunction1_Double { + public static final class RExp extends RandFunction1_Double { @Override - public double evaluate(double scale, RandomNumberProvider rand) { + public double execute(double scale, RandomNumberProvider rand) { if (!Double.isFinite(scale) || scale <= 0.0) { return scale == 0. ? 0. : RMath.mlError(); } diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Geom.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Geom.java index 2d9c81f9f1..05443e198b 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Geom.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Geom.java @@ -35,7 +35,11 @@ import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberPr import com.oracle.truffle.r.library.stats.StatsFunctions.Function2_1; import com.oracle.truffle.r.library.stats.StatsFunctions.Function2_2; -public class Geom { +public final class Geom { + private Geom() { + // only static members + } + public static final class QGeom implements Function2_2 { @Override public double evaluate(double p, double prob, boolean lowerTail, boolean logP) { @@ -87,9 +91,9 @@ public class Geom { } } - public static final class RGeom implements RandFunction1_Double { + public static final class RGeom extends RandFunction1_Double { @Override - public double evaluate(double p, RandomNumberProvider rand) { + public double execute(double p, RandomNumberProvider rand) { if (!Double.isFinite(p) || p <= 0 || p > 1) { return RMath.mlError(); } diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/LogNormal.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/LogNormal.java index 824a4acef8..b95f76cf44 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/LogNormal.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/LogNormal.java @@ -23,15 +23,15 @@ public final class LogNormal { // only static members } - public static final class RLNorm implements RandFunction2_Double { + public static final class RLNorm extends RandFunction2_Double { private final Rnorm rnorm = new Rnorm(); @Override - public double evaluate(double meanlog, double sdlog, RandomNumberProvider rand) { + public double execute(double meanlog, double sdlog, RandomNumberProvider rand) { if (Double.isNaN(meanlog) || !Double.isFinite(sdlog) || sdlog < 0.) { return RMath.mlError(); } - return Math.exp(rnorm.evaluate(meanlog, sdlog, rand)); + return Math.exp(rnorm.execute(meanlog, sdlog, rand)); } } diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RBeta.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RBeta.java index d62ab89e7e..5e1ce8ccb6 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RBeta.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RBeta.java @@ -20,12 +20,12 @@ import static com.oracle.truffle.r.library.stats.RMath.fmin2; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_Double; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider; -public final class RBeta implements RandFunction2_Double { +public final class RBeta extends RandFunction2_Double { private static final double expmax = (DBL_MAX_EXP * M_LN2); /* = log(DBL_MAX) */ @Override - public double evaluate(double aa, double bb, RandomNumberProvider rand) { + public double execute(double aa, double bb, RandomNumberProvider rand) { if (Double.isNaN(aa) || Double.isNaN(bb) || aa < 0. || bb < 0.) { return RMath.mlError(); } diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RChisq.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RChisq.java index 67ba496753..975819f6bf 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RChisq.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RChisq.java @@ -14,16 +14,16 @@ package com.oracle.truffle.r.library.stats; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction1_Double; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider; -public final class RChisq implements RandFunction1_Double { +public final class RChisq extends RandFunction1_Double { public static double rchisq(double df, RandomNumberProvider rand) { if (!Double.isFinite(df) || df < 0.0) { return RMath.mlError(); } - return new RGamma().evaluate(df / 2.0, 2.0, rand); + return new RGamma().execute(df / 2.0, 2.0, rand); } @Override - public double evaluate(double a, RandomNumberProvider rand) { + public double execute(double a, RandomNumberProvider rand) { return rchisq(a, rand); } } diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RGamma.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RGamma.java index eea9b5aea8..c5b40e24c4 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RGamma.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RGamma.java @@ -16,7 +16,7 @@ import static com.oracle.truffle.r.library.stats.TOMS708.fabs; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_Double; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider; -public class RGamma implements RandFunction2_Double { +public class RGamma extends RandFunction2_Double { private static final double sqrt32 = 5.656854; private static final double exp_m1 = 0.36787944117144232159; /* exp(-1) = 1/e */ @@ -41,7 +41,7 @@ public class RGamma implements RandFunction2_Double { private static final double a7 = 0.1233795; @Override - public double evaluate(double a, double scale, RandomNumberProvider rand) { + public double execute(double a, double scale, RandomNumberProvider rand) { // TODO: state variables double aa = 0.; diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RHyper.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RHyper.java index 7dba16bade..7af8863b92 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RHyper.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RHyper.java @@ -21,7 +21,7 @@ import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberPr import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RError.Message; -public final class RHyper implements RandFunction3_Double { +public final class RHyper extends RandFunction3_Double { private static final double[] al = { 0.0, /* ln(0!)=ln(1) */ 0.0, /* ln(1!)=ln(1) */ @@ -87,7 +87,7 @@ public final class RHyper implements RandFunction3_Double { // rhyper(NR, NB, n) -- NR 'red', NB 'blue', n drawn, how many are 'red' @Override @TruffleBoundary - public double evaluate(double nn1in, double nn2in, double kkin, RandomNumberProvider rand) { + public double execute(double nn1in, double nn2in, double kkin, RandomNumberProvider rand) { /* extern double afc(int); */ int nn1; @@ -117,7 +117,7 @@ public final class RHyper implements RandFunction3_Double { // FIXME: Much faster to give rbinom() approx when appropriate; -> see Kuensch(1989) // Johnson, Kotz,.. p.258 (top) mention the *four* different binomial approximations if (kkin == 1.) { // Bernoulli - return rbinom.evaluate(kkin, nn1in / (nn1in + nn2in), rand); + return rbinom.execute(kkin, nn1in / (nn1in + nn2in), rand); } // Slow, but safe: return F^{-1}(U) where F(.) = phyper(.) and U ~ U[0,1] return QHyper.qhyper(rand.unifRand(), nn1in, nn2in, kkin, false, false); diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RLogis.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RLogis.java index 961dafe00d..f21be8e24c 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RLogis.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RLogis.java @@ -14,9 +14,9 @@ package com.oracle.truffle.r.library.stats; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_Double; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider; -public final class RLogis implements RandFunction2_Double { +public final class RLogis extends RandFunction2_Double { @Override - public double evaluate(double location, double scale, RandomNumberProvider rand) { + public double execute(double location, double scale, RandomNumberProvider rand) { if (Double.isNaN(location) || !Double.isFinite(scale)) { return RMath.mlError(); } diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RMultinom.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RMultinom.java index 7a49a79e36..8480bde3c0 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RMultinom.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RMultinom.java @@ -163,7 +163,7 @@ public abstract class RMultinom extends RExternalBuiltinNode.Arg3 { if (probK.compareTo(BigDecimal.ZERO) != 0) { pp = probK.divide(pTot, RoundingMode.HALF_UP).doubleValue(); // System.out.printf("[%d] %.17f\n", k + 1, pp); - rN[rnStartIdx + k] = ((pp < 1.) ? (int) rbinom.evaluate((double) n, pp, rand) : + rN[rnStartIdx + k] = ((pp < 1.) ? (int) rbinom.execute((double) n, pp, rand) : /* >= 1; > 1 happens because of rounding */ n); n -= rN[rnStartIdx + k]; diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RNbinomMu.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RNbinomMu.java index c095616fa1..03118361e2 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RNbinomMu.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RNbinomMu.java @@ -14,17 +14,17 @@ package com.oracle.truffle.r.library.stats; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_Double; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider; -public final class RNbinomMu implements RandFunction2_Double { +public final class RNbinomMu extends RandFunction2_Double { private final RGamma rgamma = new RGamma(); @Override - public double evaluate(double size, double mu, RandomNumberProvider rand) { + public double execute(double size, double mu, RandomNumberProvider rand) { if (!Double.isFinite(mu) || Double.isNaN(size) || size <= 0 || mu < 0) { return RMath.mlError(); } if (!Double.isFinite(size)) { size = Double.MAX_VALUE / 2.; } - return (mu == 0) ? 0 : RPois.rpois(rgamma.evaluate(size, mu / size, rand), rand); + return (mu == 0) ? 0 : RPois.rpois(rgamma.execute(size, mu / size, rand), rand); } } diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RNchisq.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RNchisq.java index 5e22bda3f5..3b17100cc5 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RNchisq.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RNchisq.java @@ -16,24 +16,24 @@ package com.oracle.truffle.r.library.stats; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_Double; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider; -public final class RNchisq implements RandFunction2_Double { +public final class RNchisq extends RandFunction2_Double { private final RGamma rgamma = new RGamma(); @Override - public double evaluate(double df, double lambda, RandomNumberProvider rand) { + public double execute(double df, double lambda, RandomNumberProvider rand) { if (!Double.isFinite(df) || !Double.isFinite(lambda) || df < 0. || lambda < 0.) { return RMath.mlError(); } if (lambda == 0.) { - return (df == 0.) ? 0. : rgamma.evaluate(df / 2., 2., rand); + return (df == 0.) ? 0. : rgamma.execute(df / 2., 2., rand); } else { double r = RPois.rpois(lambda / 2., rand); if (r > 0.) { r = RChisq.rchisq(2. * r, rand); } if (df > 0.) { - r += rgamma.evaluate(df / 2., 2., rand); + r += rgamma.execute(df / 2., 2., rand); } return r; } diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RPois.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RPois.java index 2638fa34de..397cd4f7a7 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RPois.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RPois.java @@ -16,7 +16,7 @@ import static com.oracle.truffle.r.library.stats.MathConstants.M_1_SQRT_2PI; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction1_Double; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider; -public final class RPois implements RandFunction1_Double { +public final class RPois extends RandFunction1_Double { private static final double a0 = -0.5; private static final double a1 = 0.3333333; @@ -38,25 +38,25 @@ public final class RPois implements RandFunction1_Double { /* These are static --- persistent between calls for same mu : */ // TODO: state variables - static int l = 0; - static int m = 0; - static double b1; - static double b2; - static double c = 0; - static double c0 = 0; - static double c1 = 0; - static double c2 = 0; - static double c3 = 0; - static double[] pp = new double[36]; - static double p0 = 0; - static double p = 0; - static double q = 0; - static double s = 0; - static double d = 0; - static double omega = 0; - static double bigL = 0; /* integer "w/o overflow" */ - static double muprev = 0.; - static double muprev2 = 0.; /* , muold = 0. */ + private static int l = 0; + private static int m = 0; + private static double b1; + private static double b2; + private static double c = 0; + private static double c0 = 0; + private static double c1 = 0; + private static double c2 = 0; + private static double c3 = 0; + private static final double[] pp = new double[36]; + private static double p0 = 0; + private static double p = 0; + private static double q = 0; + private static double s = 0; + private static double d = 0; + private static double omega = 0; + private static double bigL = 0; /* integer "w/o overflow" */ + private static double muprev = 0.; + private static double muprev2 = 0.; /* , muold = 0. */ public static double rpois(double mu, RandomNumberProvider rand) { /* Local Vars [initialize some for -Wall]: */ @@ -275,7 +275,7 @@ public final class RPois implements RandFunction1_Double { } @Override - public double evaluate(double mu, RandomNumberProvider rand) { + public double execute(double mu, RandomNumberProvider rand) { return rpois(mu, rand); } } diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RWeibull.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RWeibull.java index 3686f96591..638e406f54 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RWeibull.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RWeibull.java @@ -14,9 +14,9 @@ package com.oracle.truffle.r.library.stats; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_Double; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider; -public final class RWeibull implements RandFunction2_Double { +public final class RWeibull extends RandFunction2_Double { @Override - public double evaluate(double shape, double scale, RandomNumberProvider rand) { + public double execute(double shape, double scale, RandomNumberProvider rand) { if (!Double.isFinite(shape) || !Double.isFinite(scale) || shape <= 0. || scale <= 0.) { return scale == 0. ? 0. : RMath.mlError(); } diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandGenerationFunctions.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandGenerationFunctions.java index b3548e5ea3..2bdaeeef56 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandGenerationFunctions.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandGenerationFunctions.java @@ -22,16 +22,21 @@ import java.util.Arrays; import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.dsl.TypeSystemReference; import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.LoopConditionProfile; -import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.r.library.stats.RandGenerationFunctionsFactory.ConvertToLengthNodeGen; +import com.oracle.truffle.r.library.stats.RandGenerationFunctionsFactory.RandFunction1NodeGen; +import com.oracle.truffle.r.library.stats.RandGenerationFunctionsFactory.RandFunction2NodeGen; +import com.oracle.truffle.r.library.stats.RandGenerationFunctionsFactory.RandFunction3NodeGen; +import com.oracle.truffle.r.nodes.EmptyTypeSystemFlatLayout; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode; import com.oracle.truffle.r.nodes.profile.VectorLengthProfile; import com.oracle.truffle.r.nodes.unary.CastIntegerNode; import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.RInternalError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDouble; @@ -43,6 +48,11 @@ import com.oracle.truffle.r.runtime.rng.RRNG; import com.oracle.truffle.r.runtime.rng.RRNG.NormKind; import com.oracle.truffle.r.runtime.rng.RandomNumberGenerator; +/** + * Contains infrastructure for R external functions implementing generation of a random value from + * given random value distribution. To implement such external function, implement one of: + * {@link RandFunction3_Double}, {@link RandFunction2_Double} or {@link RandFunction1_Double}. + */ public final class RandGenerationFunctions { @CompilationFinal private static final RDouble DUMMY_VECTOR = RDouble.valueOf(1); @@ -51,14 +61,22 @@ public final class RandGenerationFunctions { } public static final class RandomNumberProvider { - private final RandomNumberGenerator generator; - private final NormKind normKind; + final RandomNumberGenerator generator; + final NormKind normKind; public RandomNumberProvider(RandomNumberGenerator generator, NormKind normKind) { this.generator = generator; this.normKind = normKind; } + public static RandomNumberProvider fromCurrentRNG() { + return new RandomNumberProvider(RRNG.currentGenerator(), RRNG.currentNormKind()); + } + + protected boolean isSame(RandomNumberProvider other) { + return this.generator == other.generator && this.normKind == other.normKind; + } + public double unifRand() { return generator.genrandDouble(); } @@ -74,135 +92,32 @@ public final class RandGenerationFunctions { // inspired by the DEFRAND{X}_REAL and DEFRAND{X}_INT macros in GnuR - public interface RandFunction3_Double { - double evaluate(double a, double b, double c, RandomNumberProvider rand); + public abstract static class RandFunction3_Double extends Node { + public abstract double execute(double a, double b, double c, RandomNumberProvider rand); } - public interface RandFunction2_Double extends RandFunction3_Double { - double evaluate(double a, double b, RandomNumberProvider rand); + public abstract static class RandFunction2_Double extends RandFunction3_Double { + public abstract double execute(double a, double b, RandomNumberProvider rand); @Override - default double evaluate(double a, double b, double c, RandomNumberProvider rand) { - return evaluate(a, b, rand); + public final double execute(double a, double b, double c, RandomNumberProvider rand) { + return execute(a, b, rand); } } - public interface RandFunction1_Double extends RandFunction2_Double { - double evaluate(double a, RandomNumberProvider rand); + public abstract static class RandFunction1_Double extends RandFunction2_Double { + public abstract double execute(double a, RandomNumberProvider rand); @Override - default double evaluate(double a, double b, RandomNumberProvider rand) { - return evaluate(a, rand); - } - } - - static final class RandGenerationProfiles { - final BranchProfile nanResult = BranchProfile.create(); - final BranchProfile nan = BranchProfile.create(); - final VectorLengthProfile resultVectorLengthProfile = VectorLengthProfile.create(); - final LoopConditionProfile loopConditionProfile = LoopConditionProfile.createCountingProfile(); - private final ValueProfile randClassProfile = ValueProfile.createClassProfile(); - private final ValueProfile generatorProfile = ValueProfile.createIdentityProfile(); - private final ValueProfile normKindProfile = ValueProfile.createEqualityProfile(); - - public static RandGenerationProfiles create() { - return new RandGenerationProfiles(); - } - - public RandomNumberProvider createRandProvider() { - return new RandomNumberProvider(randClassProfile.profile(generatorProfile.profile(RRNG.currentGenerator())), normKindProfile.profile(RRNG.currentNormKind())); - } - } - - private static RAbstractIntVector evaluate3Int(Node node, RandFunction3_Double function, int lengthIn, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, - RandGenerationProfiles profiles) { - int length = lengthIn; - int aLength = a.getLength(); - int bLength = b.getLength(); - int cLength = c.getLength(); - if (aLength == 0 || bLength == 0 || cLength == 0) { - profiles.nanResult.enter(); - RError.warning(SHOW_CALLER, RError.Message.NA_PRODUCED); - int[] nansResult = new int[length]; - Arrays.fill(nansResult, RRuntime.INT_NA); - return RDataFactory.createIntVector(nansResult, false); - } - - length = profiles.resultVectorLengthProfile.profile(length); - RNode.reportWork(node, length); - boolean nans = false; - int[] result = new int[length]; - RRNG.getRNGState(); - RandomNumberProvider rand = profiles.createRandProvider(); - profiles.loopConditionProfile.profileCounted(length); - for (int i = 0; profiles.loopConditionProfile.inject(i < length); i++) { - double aValue = a.getDataAt(i % aLength); - double bValue = b.getDataAt(i % bLength); - double cValue = c.getDataAt(i % cLength); - double value = function.evaluate(aValue, bValue, cValue, rand); - if (Double.isNaN(value) || value < Integer.MIN_VALUE || value > Integer.MAX_VALUE) { - profiles.nan.enter(); - nans = true; - result[i] = RRuntime.INT_NA; - } else { - result[i] = (int) value; - } - } - RRNG.putRNGState(); - if (nans) { - RError.warning(SHOW_CALLER, RError.Message.NA_PRODUCED); - } - return RDataFactory.createIntVector(result, !nans); - } - - private static RAbstractDoubleVector evaluate3Double(Node node, RandFunction3_Double function, int lengthIn, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, - RandGenerationProfiles profiles) { - int length = lengthIn; - int aLength = a.getLength(); - int bLength = b.getLength(); - int cLength = c.getLength(); - if (aLength == 0 || bLength == 0 || cLength == 0) { - profiles.nanResult.enter(); - RError.warning(SHOW_CALLER, RError.Message.NA_PRODUCED); - return createVectorOf(length, RRuntime.DOUBLE_NA); - } - - length = profiles.resultVectorLengthProfile.profile(length); - RNode.reportWork(node, length); - boolean nans = false; - double[] result; - result = new double[length]; - RRNG.getRNGState(); - RandomNumberProvider rand = profiles.createRandProvider(); - profiles.loopConditionProfile.profileCounted(length); - for (int i = 0; profiles.loopConditionProfile.inject(i < length); i++) { - double aValue = a.getDataAt(i % aLength); - double bValue = b.getDataAt(i % bLength); - double cValue = c.getDataAt(i % cLength); - double value = function.evaluate(aValue, bValue, cValue, rand); - if (Double.isNaN(value) || RRuntime.isNA(value)) { - profiles.nan.enter(); - nans = true; - } - result[i] = value; + public final double execute(double a, double b, RandomNumberProvider rand) { + return execute(a, rand); } - RRNG.putRNGState(); - if (nans) { - RError.warning(SHOW_CALLER, RError.Message.NA_PRODUCED); - } - return RDataFactory.createDoubleVector(result, !nans); - } - - private static RAbstractDoubleVector createVectorOf(int length, double element) { - double[] nansResult = new double[length]; - Arrays.fill(nansResult, element); - return RDataFactory.createDoubleVector(nansResult, false); } /** * Converts given value to actual length that should be used as length of the output vector. The * argument must be cast using {@link #addLengthCast(CastBuilder)}. Using this node allows us to - * avoid casting of long vectors to integers, if we only need to know their length. + * avoid casting of long vectors to integers if we only need to know their length. */ protected abstract static class ConvertToLength extends Node { public abstract int execute(RAbstractVector value); @@ -229,112 +144,242 @@ public final class RandGenerationFunctions { } } - public abstract static class Function3_IntNode extends RExternalBuiltinNode.Arg4 { - private final RandFunction3_Double function; + /** + * Executor node handles the validation, the loop over all vector elements, the creation of the + * result vector, and similar. The random function is provided as implementation of + * {@link RandFunction3_Double}. + */ + @TypeSystemReference(EmptyTypeSystemFlatLayout.class) + protected abstract static class RandFunctionExecutorBase extends Node { + static final class RandGenerationNodeData { + final BranchProfile nanResult = BranchProfile.create(); + final BranchProfile nan = BranchProfile.create(); + final VectorLengthProfile resultVectorLengthProfile = VectorLengthProfile.create(); + final LoopConditionProfile loopConditionProfile = LoopConditionProfile.createCountingProfile(); + + public static RandGenerationNodeData create() { + return new RandGenerationNodeData(); + } + } + + public abstract Object execute(RAbstractVector length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandomNumberProvider rand); + @Child private ConvertToLength convertToLength = ConvertToLengthNodeGen.create(); - protected Function3_IntNode(RandFunction3_Double function) { - this.function = function; + @Specialization(guards = {"randCached.isSame(rand)"}) + protected final Object evaluateWithCached(RAbstractVector lengthVec, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, + @SuppressWarnings("unused") RandomNumberProvider rand, + @Cached("rand") RandomNumberProvider randCached, + @Cached("create()") RandGenerationNodeData nodeData) { + return evaluateWrapper(lengthVec, a, b, c, randCached, nodeData); } - @Override - protected void createCasts(CastBuilder casts) { - ConvertToLength.addLengthCast(casts); - casts.arg(1).asDoubleVector(); - casts.arg(2).asDoubleVector(); - casts.arg(3).asDoubleVector(); + @Specialization(contains = "evaluateWithCached") + protected final Object evaluateFallback(RAbstractVector lengthVec, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandomNumberProvider rand, + @Cached("create()") RandGenerationNodeData nodeData) { + return evaluateWrapper(lengthVec, a, b, c, rand, nodeData); } - @Specialization - protected RAbstractIntVector evaluate(RAbstractVector length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, - @Cached("create()") RandGenerationProfiles profiles) { - return evaluate3Int(this, function, convertToLength.execute(length), a, b, c, profiles); + private Object evaluateWrapper(RAbstractVector lengthVec, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandomNumberProvider rand, + RandGenerationNodeData nodeData) { + int length = nodeData.resultVectorLengthProfile.profile(convertToLength.execute(lengthVec)); + RNode.reportWork(this, length); + Object result = evaluate(length, a, b, c, nodeData, rand); + return result; + } + + Object evaluate(int length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandGenerationNodeData nodeData, RandomNumberProvider randProvider) { + // DSL generates code for this class too, with abstract method it would not compile + throw RInternalError.shouldNotReachHere("must be overridden"); + } + + static void putRNGState() { + // Note: we call putRNGState only if we actually changed the state, i.e. called random + // number generation. We do not need to getRNGState() because the parent wrapper node + // should do that for us + RRNG.putRNGState(); + } + + static void showNAWarning() { + RError.warning(SHOW_CALLER, RError.Message.NA_PRODUCED); } } - public abstract static class Function2_IntNode extends RExternalBuiltinNode.Arg3 { - private final RandFunction2_Double function; - @Child private ConvertToLength convertToLength = ConvertToLengthNodeGen.create(); + @TypeSystemReference(EmptyTypeSystemFlatLayout.class) + protected abstract static class RandFunctionIntExecutorNode extends RandFunctionExecutorBase { + @Child private RandFunction3_Double function; - protected Function2_IntNode(RandFunction2_Double function) { + protected RandFunctionIntExecutorNode(RandFunction3_Double function) { this.function = function; } @Override - protected void createCasts(CastBuilder casts) { - ConvertToLength.addLengthCast(casts); - casts.arg(1).asDoubleVector(); - casts.arg(2).asDoubleVector(); - } + protected RAbstractIntVector evaluate(int length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandGenerationNodeData nodeData, + RandomNumberProvider randProvider) { + int aLength = a.getLength(); + int bLength = b.getLength(); + int cLength = c.getLength(); + if (aLength == 0 || bLength == 0 || cLength == 0) { + nodeData.nanResult.enter(); + showNAWarning(); + int[] nansResult = new int[length]; + Arrays.fill(nansResult, RRuntime.INT_NA); + return RDataFactory.createIntVector(nansResult, false); + } - @Specialization - protected RAbstractIntVector evaluate(RAbstractVector length, RAbstractDoubleVector a, RAbstractDoubleVector b, - @Cached("create()") RandGenerationProfiles profiles) { - return evaluate3Int(this, function, convertToLength.execute(length), a, b, DUMMY_VECTOR, profiles); + boolean nans = false; + int[] result = new int[length]; + nodeData.loopConditionProfile.profileCounted(length); + for (int i = 0; nodeData.loopConditionProfile.inject(i < length); i++) { + double aValue = a.getDataAt(i % aLength); + double bValue = b.getDataAt(i % bLength); + double cValue = c.getDataAt(i % cLength); + double value = function.execute(aValue, bValue, cValue, randProvider); + if (Double.isNaN(value) || value < Integer.MIN_VALUE || value > Integer.MAX_VALUE) { + nodeData.nan.enter(); + nans = true; + result[i] = RRuntime.INT_NA; + } else { + result[i] = (int) value; + } + } + putRNGState(); + if (nans) { + showNAWarning(); + } + return RDataFactory.createIntVector(result, !nans); } } - public abstract static class Function1_IntNode extends RExternalBuiltinNode.Arg2 { - private final RandFunction1_Double function; - @Child private ConvertToLength convertToLength = ConvertToLengthNodeGen.create(); + @TypeSystemReference(EmptyTypeSystemFlatLayout.class) + protected abstract static class RandFunctionDoubleExecutorNode extends RandFunctionExecutorBase { + @Child private RandFunction3_Double function; - protected Function1_IntNode(RandFunction1_Double function) { + protected RandFunctionDoubleExecutorNode(RandFunction3_Double function) { this.function = function; } + @Override + protected RAbstractDoubleVector evaluate(int length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandGenerationNodeData nodeData, + RandomNumberProvider randProvider) { + int aLength = a.getLength(); + int bLength = b.getLength(); + int cLength = c.getLength(); + if (aLength == 0 || bLength == 0 || cLength == 0) { + nodeData.nanResult.enter(); + showNAWarning(); + double[] nansResult = new double[length]; + Arrays.fill(nansResult, RRuntime.DOUBLE_NA); + return RDataFactory.createDoubleVector(nansResult, false); + } + + boolean nans = false; + double[] result; + result = new double[length]; + nodeData.loopConditionProfile.profileCounted(length); + for (int i = 0; nodeData.loopConditionProfile.inject(i < length); i++) { + double aValue = a.getDataAt(i % aLength); + double bValue = b.getDataAt(i % bLength); + double cValue = c.getDataAt(i % cLength); + double value = function.execute(aValue, bValue, cValue, randProvider); + if (Double.isNaN(value) || RRuntime.isNA(value)) { + nodeData.nan.enter(); + nans = true; + } + result[i] = value; + } + putRNGState(); + if (nans) { + showNAWarning(); + } + return RDataFactory.createDoubleVector(result, !nans); + } + } + + public abstract static class RandFunction3Node extends RExternalBuiltinNode.Arg4 { + @Child private RandFunctionExecutorBase inner; + + protected RandFunction3Node(RandFunctionExecutorBase inner) { + this.inner = inner; + } + + public static RandFunction3Node createInt(RandFunction3_Double function) { + return RandFunction3NodeGen.create(RandGenerationFunctionsFactory.RandFunctionIntExecutorNodeGen.create(function)); + } + + public static RandFunction3Node createDouble(RandFunction3_Double function) { + return RandFunction3NodeGen.create(RandGenerationFunctionsFactory.RandFunctionDoubleExecutorNodeGen.create(function)); + } + @Override protected void createCasts(CastBuilder casts) { ConvertToLength.addLengthCast(casts); casts.arg(1).asDoubleVector(); + casts.arg(2).asDoubleVector(); + casts.arg(3).asDoubleVector(); } @Specialization - protected RAbstractIntVector evaluate(RAbstractVector length, RAbstractDoubleVector a, - @Cached("create()") RandGenerationProfiles profiles) { - return evaluate3Int(this, function, convertToLength.execute(length), a, DUMMY_VECTOR, DUMMY_VECTOR, profiles); + protected Object evaluate(RAbstractVector length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c) { + RRNG.getRNGState(); + return inner.execute(length, a, b, c, RandomNumberProvider.fromCurrentRNG()); } } - public abstract static class Function1_DoubleNode extends RExternalBuiltinNode.Arg2 { - private final RandFunction1_Double function; - @Child private ConvertToLength convertToLength = ConvertToLengthNodeGen.create(); + public abstract static class RandFunction2Node extends RExternalBuiltinNode.Arg3 { + @Child private RandFunctionExecutorBase inner; - protected Function1_DoubleNode(RandFunction1_Double function) { - this.function = function; + protected RandFunction2Node(RandFunctionExecutorBase inner) { + this.inner = inner; + } + + public static RandFunction2Node createInt(RandFunction2_Double function) { + return RandFunction2NodeGen.create(RandGenerationFunctionsFactory.RandFunctionIntExecutorNodeGen.create(function)); + } + + public static RandFunction2Node createDouble(RandFunction2_Double function) { + return RandFunction2NodeGen.create(RandGenerationFunctionsFactory.RandFunctionDoubleExecutorNodeGen.create(function)); } @Override protected void createCasts(CastBuilder casts) { ConvertToLength.addLengthCast(casts); casts.arg(1).asDoubleVector(); + casts.arg(2).asDoubleVector(); } @Specialization - protected RAbstractDoubleVector evaluate(RAbstractVector length, RAbstractDoubleVector a, - @Cached("create()") RandGenerationProfiles profiles) { - return evaluate3Double(this, function, convertToLength.execute(length), a, DUMMY_VECTOR, DUMMY_VECTOR, profiles); + protected Object evaluate(RAbstractVector length, RAbstractDoubleVector a, RAbstractDoubleVector b) { + RRNG.getRNGState(); + return inner.execute(length, a, b, DUMMY_VECTOR, RandomNumberProvider.fromCurrentRNG()); } } - public abstract static class Function2_DoubleNode extends RExternalBuiltinNode.Arg3 { - private final RandFunction2_Double function; - @Child private ConvertToLength convertToLength = ConvertToLengthNodeGen.create(); + public abstract static class RandFunction1Node extends RExternalBuiltinNode.Arg2 { + @Child private RandFunctionExecutorBase inner; - protected Function2_DoubleNode(RandFunction2_Double function) { - this.function = function; + protected RandFunction1Node(RandFunctionExecutorBase inner) { + this.inner = inner; + } + + public static RandFunction1Node createInt(RandFunction1_Double function) { + return RandFunction1NodeGen.create(RandGenerationFunctionsFactory.RandFunctionIntExecutorNodeGen.create(function)); + } + + public static RandFunction1Node createDouble(RandFunction1_Double function) { + return RandFunction1NodeGen.create(RandGenerationFunctionsFactory.RandFunctionDoubleExecutorNodeGen.create(function)); } @Override protected void createCasts(CastBuilder casts) { ConvertToLength.addLengthCast(casts); casts.arg(1).asDoubleVector(); - casts.arg(2).asDoubleVector(); } @Specialization - protected RAbstractDoubleVector evaluate(RAbstractVector length, RAbstractDoubleVector a, RAbstractDoubleVector b, - @Cached("create()") RandGenerationProfiles profiles) { - return evaluate3Double(this, function, convertToLength.execute(length), a, b, DUMMY_VECTOR, profiles); + protected Object evaluate(RAbstractVector length, RAbstractDoubleVector a) { + RRNG.getRNGState(); + return inner.execute(length, a, DUMMY_VECTOR, DUMMY_VECTOR, RandomNumberProvider.fromCurrentRNG()); } } } diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rbinom.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rbinom.java index db6545e1e1..edd2864e79 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rbinom.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rbinom.java @@ -18,12 +18,12 @@ import com.oracle.truffle.r.runtime.RRuntime; // transcribed from rbinom.c -public final class Rbinom implements RandFunction2_Double { +public final class Rbinom extends RandFunction2_Double { private final Qbinom qbinom = new Qbinom(); @Override - public double evaluate(double nin, double pp, RandomNumberProvider rand) { + public double execute(double nin, double pp, RandomNumberProvider rand) { double psave = -1.0; int nsave = -1; diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rf.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rf.java index 21cbb6d2bc..c753f1777f 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rf.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rf.java @@ -14,9 +14,9 @@ package com.oracle.truffle.r.library.stats; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_Double; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider; -public final class Rf implements RandFunction2_Double { +public final class Rf extends RandFunction2_Double { @Override - public double evaluate(double n1, double n2, RandomNumberProvider rand) { + public double execute(double n1, double n2, RandomNumberProvider rand) { if (Double.isNaN(n1) || Double.isNaN(n2) || n1 <= 0. || n2 <= 0.) { return RMath.mlError(); } diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rnorm.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rnorm.java index 1996274e87..be60d6863d 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rnorm.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rnorm.java @@ -11,16 +11,27 @@ */ package com.oracle.truffle.r.library.stats; +import com.oracle.truffle.api.profiles.BranchProfile; +import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_Double; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider; -public final class Rnorm implements RandFunction2_Double { +public final class Rnorm extends RandFunction2_Double { + private final BranchProfile errorProfile = BranchProfile.create(); + private final ConditionProfile zeroSigmaProfile = ConditionProfile.createBinaryProfile(); + private final ValueProfile sigmaValueProfile = ValueProfile.createEqualityProfile(); + private final ValueProfile muValueProfile = ValueProfile.createEqualityProfile(); + @Override - public double evaluate(double mu, double sigma, RandomNumberProvider rand) { + public double execute(double muIn, double sigmaIn, RandomNumberProvider rand) { + double sigma = sigmaValueProfile.profile(sigmaIn); + double mu = muValueProfile.profile(muIn); if (Double.isNaN(mu) || !Double.isFinite(sigma) || sigma < 0.) { + errorProfile.enter(); return RMath.mlError(); } - if (sigma == 0. || !Double.isFinite(mu)) { + if (zeroSigmaProfile.profile(sigma == 0. || !Double.isFinite(mu))) { return mu; /* includes mu = +/- Inf with finite sigma */ } else { return mu + sigma * rand.normRand(); diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rt.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rt.java index 51a266499e..80c8ea59fc 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rt.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rt.java @@ -16,9 +16,9 @@ import static com.oracle.truffle.r.library.stats.RChisq.rchisq; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction1_Double; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider; -public final class Rt implements RandFunction1_Double { +public final class Rt extends RandFunction1_Double { @Override - public double evaluate(double df, RandomNumberProvider rand) { + public double execute(double df, RandomNumberProvider rand) { if (Double.isNaN(df) || df <= 0.0) { return RMath.mlError(); } diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Runif.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Runif.java index 5701011cf5..d5729f4a3b 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Runif.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Runif.java @@ -22,17 +22,28 @@ */ package com.oracle.truffle.r.library.stats; +import com.oracle.truffle.api.profiles.BranchProfile; +import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_Double; import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider; import com.oracle.truffle.r.runtime.RRuntime; -public final class Runif implements RandFunction2_Double { +public final class Runif extends RandFunction2_Double { + private final BranchProfile errorProfile = BranchProfile.create(); + private final ConditionProfile minEqualsMaxProfile = ConditionProfile.createBinaryProfile(); + private final ValueProfile minValueProfile = ValueProfile.createEqualityProfile(); + private final ValueProfile maxValueProfile = ValueProfile.createEqualityProfile(); + @Override - public double evaluate(double min, double max, RandomNumberProvider rand) { + public double execute(double minIn, double maxIn, RandomNumberProvider rand) { + double min = minValueProfile.profile(minIn); + double max = maxValueProfile.profile(maxIn); if (!RRuntime.isFinite(min) || !RRuntime.isFinite(max) || max < min) { + errorProfile.enter(); return RMath.mlError(); } - if (min == max) { + if (minEqualsMaxProfile.profile(min == max)) { return min; } return min + rand.unifRand() * (max - min); diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Signrank.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Signrank.java index 1b33cb4ddb..a6747c5660 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Signrank.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Signrank.java @@ -17,10 +17,13 @@ import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction1_ import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider; public final class Signrank { + private Signrank() { + // only static members + } - public static final class RSignrank implements RandFunction1_Double { + public static final class RSignrank extends RandFunction1_Double { @Override - public double evaluate(double n, RandomNumberProvider rand) { + public double execute(double n, RandomNumberProvider rand) { int i; int k; double r; diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Wilcox.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Wilcox.java index 0064741144..acf8b5555f 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Wilcox.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Wilcox.java @@ -19,10 +19,13 @@ import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; public final class Wilcox { + private Wilcox() { + // only static members + } - public static final class RWilcox implements RandFunction2_Double { + public static final class RWilcox extends RandFunction2_Double { @Override - public double evaluate(double m, double n, RandomNumberProvider rand) { + public double execute(double m, double n, RandomNumberProvider rand) { int i; int j; int k; diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/CallAndExternalFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/CallAndExternalFunctions.java index 279790c9b7..2c2f2c7983 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/CallAndExternalFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/CallAndExternalFunctions.java @@ -85,7 +85,9 @@ import com.oracle.truffle.r.library.stats.RNchisq; import com.oracle.truffle.r.library.stats.RPois; import com.oracle.truffle.r.library.stats.RWeibull; import com.oracle.truffle.r.library.stats.RandGenerationFunctions; -import com.oracle.truffle.r.library.stats.RandGenerationFunctionsFactory; +import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction1Node; +import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2Node; +import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction3Node; import com.oracle.truffle.r.library.stats.Rbinom; import com.oracle.truffle.r.library.stats.Rf; import com.oracle.truffle.r.library.stats.Rnorm; @@ -249,41 +251,41 @@ public class CallAndExternalFunctions { case "qnorm": return StatsFunctionsFactory.Function3_2NodeGen.create(new Qnorm()); case "rnorm": - return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new Rnorm()); + return RandFunction2Node.createDouble(new Rnorm()); case "runif": - return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new Runif()); + return RandFunction2Node.createDouble(new Runif()); case "rbeta": - return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new RBeta()); + return RandFunction2Node.createDouble(new RBeta()); case "rgamma": - return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new RGamma()); + return RandFunction2Node.createDouble(new RGamma()); case "rcauchy": - return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new RCauchy()); + return RandFunction2Node.createDouble(new RCauchy()); case "rf": - return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new Rf()); + return RandFunction2Node.createDouble(new Rf()); case "rlogis": - return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new RLogis()); + return RandFunction2Node.createDouble(new RLogis()); case "rweibull": - return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new RWeibull()); + return RandFunction2Node.createDouble(new RWeibull()); case "rnchisq": - return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new RNchisq()); + return RandFunction2Node.createDouble(new RNchisq()); case "rnbinom_mu": - return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new RNbinomMu()); + return RandFunction2Node.createDouble(new RNbinomMu()); case "rwilcox": - return RandGenerationFunctionsFactory.Function2_IntNodeGen.create(new RWilcox()); + return RandFunction2Node.createInt(new RWilcox()); case "rchisq": - return RandGenerationFunctionsFactory.Function1_DoubleNodeGen.create(new RChisq()); + return RandFunction1Node.createDouble(new RChisq()); case "rexp": - return RandGenerationFunctionsFactory.Function1_DoubleNodeGen.create(new RExp()); + return RandFunction1Node.createDouble(new RExp()); case "rgeom": - return RandGenerationFunctionsFactory.Function1_IntNodeGen.create(new RGeom()); + return RandFunction1Node.createInt(new RGeom()); case "rpois": - return RandGenerationFunctionsFactory.Function1_IntNodeGen.create(new RPois()); + return RandFunction1Node.createInt(new RPois()); case "rt": - return RandGenerationFunctionsFactory.Function1_DoubleNodeGen.create(new Rt()); + return RandFunction1Node.createDouble(new Rt()); case "rsignrank": - return RandGenerationFunctionsFactory.Function1_IntNodeGen.create(new RSignrank()); + return RandFunction1Node.createInt(new RSignrank()); case "rhyper": - return RandGenerationFunctionsFactory.Function3_IntNodeGen.create(new RHyper()); + return RandFunction3Node.createInt(new RHyper()); case "qgamma": return StatsFunctionsFactory.Function3_2NodeGen.create(new QgammaFunc()); case "dbinom": @@ -291,7 +293,7 @@ public class CallAndExternalFunctions { case "qbinom": return StatsFunctionsFactory.Function3_2NodeGen.create(new Qbinom()); case "rbinom": - return RandGenerationFunctionsFactory.Function2_IntNodeGen.create(new Rbinom()); + return RandFunction2Node.createInt(new Rbinom()); case "pbinom": return StatsFunctionsFactory.Function3_2NodeGen.create(new Pbinom()); case "pbeta": @@ -329,7 +331,7 @@ public class CallAndExternalFunctions { case "dt": return StatsFunctionsFactory.Function2_1NodeGen.create(new Dt()); case "rlnorm": - return RandGenerationFunctionsFactory.Function2_DoubleNodeGen.create(new LogNormal.RLNorm()); + return RandFunction2Node.createDouble(new LogNormal.RLNorm()); case "dlnorm": return StatsFunctionsFactory.Function3_1NodeGen.create(new DLNorm()); case "qlnorm": -- GitLab