Skip to content
Snippets Groups Projects
Commit 5913bafe authored by stepan's avatar stepan
Browse files

Refactor common quantile search code from QPois and QBinom + tests

parent d513ed12
No related branches found
No related tags found
No related merge requests found
Showing
with 805 additions and 693 deletions
...@@ -41,7 +41,6 @@ public final class QPois implements Function2_2 { ...@@ -41,7 +41,6 @@ public final class QPois implements Function2_2 {
return e.result; return e.result;
} }
double mu = lambda;
double sigma = Math.sqrt(lambda); double sigma = Math.sqrt(lambda);
/* gamma = sigma; PR#8058 should be kurtosis which is mu^-0.5 */ /* gamma = sigma; PR#8058 should be kurtosis which is mu^-0.5 */
double gamma = 1.0 / sigma; double gamma = 1.0 / sigma;
...@@ -70,61 +69,18 @@ public final class QPois implements Function2_2 { ...@@ -70,61 +69,18 @@ public final class QPois implements Function2_2 {
// #ifdef HAVE_NEARBYINT // #ifdef HAVE_NEARBYINT
// y = nearbyint(mu + sigma * (z + gamma * (z*z - 1) / 6)); // y = nearbyint(mu + sigma * (z + gamma * (z*z - 1) / 6));
// #else // #else
double y = Math.round(mu + sigma * (z + gamma * (z * z - 1) / 6)); double y = RMath.round(lambda + sigma * (z + gamma * (z * z - 1) / 6));
z = ppois.evaluate(y, lambda, /* lower_tail */true, /* log_p */false);
/* fuzz to ensure left continuity; 1 - 1e-7 may lose too much : */ /* fuzz to ensure left continuity; 1 - 1e-7 may lose too much : */
p *= 1 - 64 * DBL_EPSILON; p *= 1 - 64 * DBL_EPSILON;
/* If the mean is not too large a simple search is OK */ QuantileSearch search = new QuantileSearch((quantile, lt, lp) -> ppois.evaluate(quantile, lambda, lt, lp));
if (lambda < 1e5) { if (lambda < 1e5) {
return search(y, z, p, lambda, 1).y; /* If the mean is not too large a simple search is OK */
return search.simpleSearch(y, p, 1);
} else { } else {
/* Otherwise be a bit cleverer in the search */ /* Otherwise be a bit cleverer in the search */
double incr = Math.floor(y * 0.001); return search.iterativeSearch(y, p);
double oldincr;
do {
oldincr = incr;
SearchResult searchResult = search(y, z, p, lambda, incr);
y = searchResult.y;
z = searchResult.z;
incr = RMath.fmax2(1, Math.floor(incr / 100));
} while (oldincr > 1 && incr > lambda * 1e-15);
return y;
}
}
private SearchResult search(double yIn, double zIn, double p, double lambda, double incr) {
if (zIn >= p) {
/* search to the left */
double y = yIn;
for (;;) {
double z = zIn;
if (y == 0 || (z = ppois.evaluate(y - incr, lambda, /* l._t. */true, /* log_p */false)) < p) {
return new SearchResult(y, z);
}
y = RMath.fmax2(0, y - incr);
}
} else { /* search to the right */
double y = yIn;
for (;;) {
y = y + incr;
double z;
if ((z = ppois.evaluate(y, lambda, /* l._t. */true, /* log_p */false)) >= p) {
return new SearchResult(y, z);
}
}
}
}
private static final class SearchResult {
final double y;
final double z;
SearchResult(double y, double z) {
this.y = y;
this.z = z;
} }
} }
} }
...@@ -19,42 +19,9 @@ import com.oracle.truffle.r.runtime.RRuntime; ...@@ -19,42 +19,9 @@ import com.oracle.truffle.r.runtime.RRuntime;
// transcribed from qbinom.c // transcribed from qbinom.c
public final class Qbinom implements StatsFunctions.Function3_2 { public final class Qbinom implements StatsFunctions.Function3_2 {
private static final class Search {
private double z;
Search(double z) {
this.z = z;
}
double doSearch(double initialY, double p, double n, double pr, double incr, Pbinom pbinom1, Pbinom pbinom2) {
double y = initialY;
if (z >= p) {
/* search to the left */
for (;;) {
double newz;
if (y == 0 || (newz = pbinom1.evaluate(y - incr, n, pr, true, false)) < p) {
return y;
}
y = Math.max(0, y - incr);
z = newz;
}
} else { /* search to the right */
for (;;) {
y = Math.min(y + incr, n);
if (y == n || (z = pbinom2.evaluate(y, n, pr, true, false)) >= p) {
return y;
}
}
}
}
}
private final BranchProfile nanProfile = BranchProfile.create(); private final BranchProfile nanProfile = BranchProfile.create();
private final ConditionProfile smallNProfile = ConditionProfile.createBinaryProfile(); private final ConditionProfile smallNProfile = ConditionProfile.createBinaryProfile();
private final Pbinom pbinom = new Pbinom(); private final Pbinom pbinom = new Pbinom();
private final Pbinom pbinomSearch1 = new Pbinom();
private final Pbinom pbinomSearch2 = new Pbinom();
@Override @Override
public double evaluate(double initialP, double n, double pr, boolean lowerTail, boolean logProb) { public double evaluate(double initialP, double n, double pr, boolean lowerTail, boolean logProb) {
...@@ -149,24 +116,14 @@ public final class Qbinom implements StatsFunctions.Function3_2 { ...@@ -149,24 +116,14 @@ public final class Qbinom implements StatsFunctions.Function3_2 {
y = n; y = n;
} }
z = pbinom.evaluate(y, n, pr, /* lowerTail */true, /* logP */false);
/* fuzz to ensure left continuity: */ /* fuzz to ensure left continuity: */
p *= 1 - 64 * RRuntime.EPSILON; p *= 1 - 64 * RRuntime.EPSILON;
Search search = new Search(z); QuantileSearch search = new QuantileSearch(n, (quantile, lt, lp) -> pbinom.evaluate(quantile, n, pr, lt, lp));
if (smallNProfile.profile(n < 1e5)) { if (smallNProfile.profile(n < 1e5)) {
return search.doSearch(y, p, n, pr, 1, pbinomSearch1, pbinomSearch2); return search.simpleSearch(y, p, 1);
} else {
return search.iterativeSearch(y, p, Math.floor(n * 0.001), 1e-15, 100);
} }
/* Otherwise be a bit cleverer in the search */
double incr = Math.floor(n * 0.001);
double oldincr;
do {
oldincr = incr;
y = search.doSearch(y, p, n, pr, incr, pbinomSearch1, pbinomSearch2);
incr = Math.max(1, Math.floor(incr / 100));
} while (oldincr > 1 && incr > n * 1e-15);
return y;
} }
} }
/*
* This material is distributed under the GNU General Public License
* Version 2. You may review the terms of this license at
* http://www.gnu.org/licenses/gpl-2.0.html
*
* Copyright (C) 1998 Ross Ihaka
* Copyright (c) 2000-2016, The R Core Team
* Copyright (c) 2003-2016, The R Foundation
* Copyright (c) 2016, Oracle and/or its affiliates
*
* All rights reserved.
*/
package com.oracle.truffle.r.library.stats;
/**
 * Searches for a quantile of a given random variable using its distribution function. The search
 * takes steps of a given size until it reaches the quantile or until it steps over it. This class
 * and its {@link #simpleSearch(double, double, double)} method correspond to the several
 * {@code do_search} functions in GnuR.
 *
 * <p>
 * Instances carry mutable state (the last evaluated distribution value), so a single instance must
 * not be used for several searches at the same time.
 */
public final class QuantileSearch {
    /** Value stored in {@link #rightSearchLimit} when the search to the right is unbounded. */
    private static final double NO_RIGHT_LIMIT = -1;

    /**
     * This is the value of the distribution function where the search finished last time. It
     * decides the direction (left/right) of the next search pass.
     */
    private double z;

    private final double rightSearchLimit;
    private final DistributionFunc distributionFunc;

    /**
     * @param rightSearchLimit If set to a non-negative value, then the search to the right will be
     *            limited by it (the result never exceeds it). A negative value means "no limit".
     * @param distributionFunc The distribution function, all parameters except the quantile,
     *            lowerTail, and logP are fixed.
     */
    public QuantileSearch(double rightSearchLimit, DistributionFunc distributionFunc) {
        this.rightSearchLimit = rightSearchLimit;
        this.distributionFunc = distributionFunc;
    }

    /**
     * Constructs the object without {@code rightSearchLimit}.
     *
     * @param distributionFunc The distribution function, all parameters except the quantile,
     *            lowerTail, and logP are fixed.
     */
    public QuantileSearch(DistributionFunc distributionFunc) {
        this(NO_RIGHT_LIMIT, distributionFunc);
    }

    /**
     * Performs a single search pass with a fixed step.
     *
     * @param yIn where to start the search (quantile)
     * @param p the target of the search (probability)
     * @param incr the size of one step
     * @return the quantile found by stepping from {@code yIn} by {@code incr}
     */
    public double simpleSearch(double yIn, double p, double incr) {
        z = distributionFunc.eval(yIn, true, false);
        return search(yIn, p, incr);
    }

    /**
     * Runs the search repeatedly, every pass continuing from the result of the previous pass with
     * the increment step divided by {@code incrDenominator} (but never smaller than 1). The
     * iteration stops once a pass with step 1 has been made, or once the next step would not be
     * greater than the current result times {@code resultFactor}; the result is then deemed 'close
     * enough' and returned.
     *
     * @param initialY where to start the search (quantile)
     * @param p the target of the search (probability)
     * @param initialIncr initial value for the increment step, must be positive
     * @param resultFactor see the method doc.
     * @param incrDenominator see the method doc.
     * @return the quantile (number close to it) for {@code p}.
     */
    public double iterativeSearch(double initialY, double p, double initialIncr, double resultFactor, double incrDenominator) {
        assert initialIncr > 0. : "initialIncr zero or negative. Maybe result of too small initialY?";
        double result = initialY;
        double oldIncr;
        double incr = initialIncr;
        z = distributionFunc.eval(initialY, true, false);
        do {
            oldIncr = incr;
            // Refine from where the previous pass ended: 'z' describes that position, so
            // restarting from initialY would pair a stale direction with the wrong start point.
            result = search(result, p, incr);
            incr = Math.max(1, Math.floor(incr / incrDenominator));
        } while (oldIncr > 1 && incr > result * resultFactor);
        return result;
    }

    /**
     * The same as {@link #iterativeSearch(double, double, double, double, double)}, but with
     * default values for the missing parameters: the initial step is
     * {@code floor(initialY * 0.001)}, the result factor is {@code 1e-15} and the step is divided
     * by {@code 100} in every iteration.
     */
    public double iterativeSearch(double initialY, double p) {
        return iterativeSearch(initialY, p, Math.floor(initialY * 0.001), 1e-15, 100);
    }

    /**
     * One search pass. The direction is chosen from {@link #z}, which must hold the distribution
     * value that belongs to {@code yIn}'s side of the target.
     */
    private double search(double yIn, double p, double incr) {
        double y = yIn;
        if (z >= p) {
            // we are at or to the right of the desired value -> move to the left
            while (true) {
                if (y == 0) {
                    return y;
                }
                z = distributionFunc.eval(y - incr, true, false);
                if (z < p) {
                    return y;
                }
                y = Math.max(0, y - incr);
            }
        } else {
            // we are to the left of the desired value -> move to the right
            while (true) {
                y = moveRight(y, incr);
                // Hitting the limit ends the search without evaluating the distribution function;
                // the test must use the same condition as moveRight, otherwise y could stay
                // clamped to the limit forever.
                if (hasRightLimit() && y == rightSearchLimit) {
                    return y;
                }
                z = distributionFunc.eval(y, true, false);
                if (z >= p) {
                    return y;
                }
            }
        }
    }

    /** Tells whether the search to the right is bounded by {@link #rightSearchLimit}. */
    private boolean hasRightLimit() {
        return rightSearchLimit >= 0;
    }

    /** Makes one step to the right, not going past the right limit if there is one. */
    private double moveRight(double y, double incr) {
        if (hasRightLimit()) {
            return Math.min(y + incr, rightSearchLimit);
        }
        return y + incr;
    }

    /**
     * Distribution function with all the distribution parameters already fixed.
     */
    @FunctionalInterface
    public interface DistributionFunc {
        double eval(double quantile, boolean lowerTail, boolean logP);
    }
}
...@@ -49,6 +49,14 @@ public final class RMath { ...@@ -49,6 +49,14 @@ public final class RMath {
return Math.floor(x + 0.5); return Math.floor(x + 0.5);
} }
/**
 * Rounds {@code x} to a whole number by delegating to {@code forceint} and returns it as a
 * {@code double}. Intended as the counterpart of the C routine {@code round}; it is not
 * interchangeable with {@code Math.round}, because that returns {@code long}, whereas this method
 * returns {@code double} and so it can handle values that do not fit into long.
 *
 * NOTE(review): the exact rounding rule (halfway cases, negative values, NaN) is whatever
 * {@code forceint} implements. C {@code round} rounds halfway cases away from zero; confirm that
 * {@code forceint} agrees for negative ties.
 *
 * @param x the value to round
 * @return the result of {@code forceint(x)}
 */
public static double round(double x) {
return forceint(x);
}
public static double fsign(double x, double y) { public static double fsign(double x, double y) {
if (Double.isNaN(x) || Double.isNaN(y)) { if (Double.isNaN(x) || Double.isNaN(y)) {
return x + y; return x + y;
......
...@@ -41,8 +41,8 @@ public final class Wilcox { ...@@ -41,8 +41,8 @@ public final class Wilcox {
throw RError.error(RError.SHOW_CALLER, CALLOC_COULD_NOT_ALLOCATE_INF); throw RError.error(RError.SHOW_CALLER, CALLOC_COULD_NOT_ALLOCATE_INF);
} }
double m = Math.round(mIn); double m = RMath.round(mIn);
double n = Math.round(nIn); double n = RMath.round(nIn);
if ((m < 0) || (n < 0)) { if ((m < 0) || (n < 0)) {
// TODO: for some reason the macro in GNUR here returns NA instead of NaN... // TODO: for some reason the macro in GNUR here returns NA instead of NaN...
// return StatsUtil.mlError(); // return StatsUtil.mlError();
......
...@@ -32,6 +32,7 @@ import com.oracle.truffle.api.dsl.NodeChild; ...@@ -32,6 +32,7 @@ import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.NodeChildren; import com.oracle.truffle.api.dsl.NodeChildren;
import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.r.library.stats.GammaFunctions; import com.oracle.truffle.r.library.stats.GammaFunctions;
import com.oracle.truffle.r.library.stats.RMath;
import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.CastBuilder;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.base.BaseGammaFunctionsFactory.DpsiFnCalcNodeGen; import com.oracle.truffle.r.nodes.builtin.base.BaseGammaFunctionsFactory.DpsiFnCalcNodeGen;
...@@ -235,7 +236,7 @@ public class BaseGammaFunctions { ...@@ -235,7 +236,7 @@ public class BaseGammaFunctions {
* use Abramowitz & Stegun 6.4.7 "Reflection Formula" psi(k, x) = (-1)^k psi(k, 1-x) * use Abramowitz & Stegun 6.4.7 "Reflection Formula" psi(k, x) = (-1)^k psi(k, 1-x)
* - pi^{n+1} (d/dx)^n cot(x) * - pi^{n+1} (d/dx)^n cot(x)
*/ */
if (x == Math.round(x)) { if (x == RMath.round(x)) {
/* non-positive integer : +Inf or NaN depends on n */ /* non-positive integer : +Inf or NaN depends on n */
// for(j=0; j < m; j++) /* k = j + n : */ // for(j=0; j < m; j++) /* k = j + n : */
// ans[j] = ((j+n) % 2) ? ML_POSINF : ML_NAN; // ans[j] = ((j+n) % 2) ? ML_POSINF : ML_NAN;
......
...@@ -115,7 +115,20 @@ public class TestDistributions extends TestBase { ...@@ -115,7 +115,20 @@ public class TestDistributions extends TestBase {
// this should show non-integer warnings for quantiles // this should show non-integer warnings for quantiles
test("5, 5, 5", withQuantiles("0.1", "-Inf", "Inf", "0.3e89")). test("5, 5, 5", withQuantiles("0.1", "-Inf", "Inf", "0.3e89")).
// too many drawn balls: should be error // too many drawn balls: should be error
test("3, 4, 10", withQuantiles("2")) test("3, 4, 10", withQuantiles("2")),
distr("pois").
addErrorParamValues("-1", "0").
test("10", withDefaultQ("5", "10", "15", "20", "30")).
// seems to be the smallest lambda for which we get some results other than 0/1
test("0.1e-6", withQuantiles("0.1e-10", "0.1", "1", "10")).
test("1e100", withQuantiles("1e99", "1e99*9.999", "1e100-1", "1e100", "1e100+100", "1e101")),
distr("binom").
addErrorParamValues("-1").
test("20, 0.3", withDefaultQ("1", "2", "10", "20", "21")).
test("10000, 0.01", withQuantiles("1", "10", "100", "500", "900", "1000")).
// non-probability value is error for the second parameter
test("10, -0.1", withQuantiles("2")).
test("10, 5", withQuantiles("2"))
}; };
// @formatter:on // @formatter:on
......
...@@ -46,7 +46,7 @@ public class TestStatFunctions extends TestBase { ...@@ -46,7 +46,7 @@ public class TestStatFunctions extends TestBase {
assertEval(Output.IgnoreWhitespace, template("set.seed(1); %0(%1)", FUNCTION3_1_NAMES, FUNCTION3_1_PARAMS)); assertEval(Output.IgnoreWhitespace, template("set.seed(1); %0(%1)", FUNCTION3_1_NAMES, FUNCTION3_1_PARAMS));
} }
private static final String[] FUNCTION2_1_NAMES = {"dchisq", "dgeom", "dpois", "dt"}; private static final String[] FUNCTION2_1_NAMES = {"dchisq", "dgeom", "dt"};
private static final String[] FUNCTION2_1_PARAMS = { private static final String[] FUNCTION2_1_PARAMS = {
"10, 10, log=TRUE", "10, 10, log=TRUE",
"3, 3, log=FALSE", "3, 3, log=FALSE",
...@@ -61,7 +61,7 @@ public class TestStatFunctions extends TestBase { ...@@ -61,7 +61,7 @@ public class TestStatFunctions extends TestBase {
assertEval(Output.IgnoreWhitespace, template("set.seed(1); %0(%1)", FUNCTION2_1_NAMES, FUNCTION2_1_PARAMS)); assertEval(Output.IgnoreWhitespace, template("set.seed(1); %0(%1)", FUNCTION2_1_NAMES, FUNCTION2_1_PARAMS));
} }
private static final String[] FUNCTION2_2_NAMES = {"pchisq", "qgeom", "pgeom", "qt", "pt", "qpois", "ppois", "qchisq"}; private static final String[] FUNCTION2_2_NAMES = {"pchisq", "qgeom", "pgeom", "qt", "pt", "qchisq"};
private static final String[] FUNCTION2_2_PARAMS = { private static final String[] FUNCTION2_2_PARAMS = {
"0, 10", "0, 10",
"c(-1, 0, 0.2, 2), rep(c(-1, 0, 0.1, 0.9, 3), 4)", "c(-1, 0, 0.2, 2), rep(c(-1, 0, 0.1, 0.9, 3), 4)",
...@@ -80,7 +80,7 @@ public class TestStatFunctions extends TestBase { ...@@ -80,7 +80,7 @@ public class TestStatFunctions extends TestBase {
assertEval(Output.IgnoreWhitespace, template("set.seed(1); %0(%1)", FUNCTION2_2_NAMES, new String[]{"rep(c(1, 0, 0.1), 5), c(NA, 0, NaN, 1/0, -1/0)"})); assertEval(Output.IgnoreWhitespace, template("set.seed(1); %0(%1)", FUNCTION2_2_NAMES, new String[]{"rep(c(1, 0, 0.1), 5), c(NA, 0, NaN, 1/0, -1/0)"}));
} }
private static final String[] FUNCTION3_2_NAMES = {"qlnorm", "plnorm", "qbinom", "qlogis", "pf", "pbinom", "plogis", "qf"}; private static final String[] FUNCTION3_2_NAMES = {"qlnorm", "plnorm", "qlogis", "pf", "plogis", "qf"};
private static final String[] FUNCTION3_2_PARAMS = { private static final String[] FUNCTION3_2_PARAMS = {
"0, 10, 10", "0, 10, 10",
"c(-1, 0, 0.2, 2), c(-1, 0, 0.1, 0.9, 3), rep(c(-1, 0, 1, 0.1, -0.1, 0.0001), 20)", "c(-1, 0, 0.2, 2), c(-1, 0, 0.1, 0.9, 3), rep(c(-1, 0, 1, 0.1, -0.1, 0.0001), 20)",
......
...@@ -79,6 +79,7 @@ com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/SplineFuncti ...@@ -79,6 +79,7 @@ com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/SplineFuncti
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsFunctions.java,gnu_r_gentleman_ihaka.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsFunctions.java,gnu_r_gentleman_ihaka.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandGenerationFunctions.java,gnu_r_gentleman_ihaka.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandGenerationFunctions.java,gnu_r_gentleman_ihaka.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RMath.java,gnu_r_ihaka.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RMath.java,gnu_r_ihaka.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/QuantileSearch.java,gnu_r_ihaka.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RMathError.java,gnu_r.core.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RMathError.java,gnu_r.core.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/LogNormal.java,gnu_r_ihaka.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/LogNormal.java,gnu_r_ihaka.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/MathConstants.java,gnu_r_ihaka.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/MathConstants.java,gnu_r_ihaka.copyright
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment