Skip to content
Snippets Groups Projects
Commit e1295d32 authored by stepan's avatar stepan
Browse files

Stat externals: qbeta, qf, qchisq

+ necessary DPQ/RMath methods
+ move all XChisq to Chisq.java
parent d4dedb8b
No related branches found
No related tags found
No related merge requests found
Showing
with 1647 additions and 43 deletions
......@@ -13,11 +13,18 @@ package com.oracle.truffle.r.library.stats;
import static com.oracle.truffle.r.library.stats.GammaFunctions.dgamma;
import static com.oracle.truffle.r.library.stats.GammaFunctions.pgamma;
import static com.oracle.truffle.r.library.stats.GammaFunctions.qgamma;
import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction1_Double;
import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider;
import com.oracle.truffle.r.library.stats.StatsFunctions.Function2_1;
import com.oracle.truffle.r.library.stats.StatsFunctions.Function2_2;
public final class Chisq {
private Chisq() {
// only static members
}
public static final class PChisq implements Function2_2 {
@Override
public double evaluate(double x, double df, boolean lowerTail, boolean logP) {
......@@ -31,4 +38,25 @@ public final class Chisq {
return dgamma(x, df / 2., 2., giveLog);
}
}
public static final class QChisq implements Function2_2 {
@Override
public double evaluate(double p, double df, boolean lowerTail, boolean logP) {
return qgamma(p, 0.5 * df, 2.0, lowerTail, logP);
}
}
public static final class RChisq extends RandFunction1_Double {
public static double rchisq(double df, RandomNumberProvider rand) {
if (!Double.isFinite(df) || df < 0.0) {
return RMath.mlError();
}
return new RGamma().execute(df / 2.0, 2.0, rand);
}
@Override
public double execute(double a, RandomNumberProvider rand) {
return rchisq(a, rand);
}
}
}
......@@ -51,6 +51,11 @@ public final class DPQ {
return logP ? Double.NEGATIVE_INFINITY : 0.;
}
// R_D_half (log_p ? -M_LN2 : 0.5)
public static double rdhalf(boolean logP) {
return logP ? -M_LN2 : 0.5;
}
// R_D__1
public static double rd1(boolean logP) {
return logP ? 0. : 1.;
......
......@@ -86,6 +86,8 @@ public final class MathConstants {
public static final double DBL_EPSILON = Math.ulp(1.0);
public static final double ML_NAN = Double.NaN;
/**
* Compute the log of a sum from logs of terms, i.e.,
*
......
......@@ -29,7 +29,7 @@ public final class Pbeta implements Function3_2 {
}
@TruffleBoundary
private static double pbetaRaw(double x, double a, double b, boolean lowerTail, boolean logProb) {
static double pbetaRaw(double x, double a, double b, boolean lowerTail, boolean logProb) {
// treat limit cases correctly here:
if (a == 0 || b == 0 || !Double.isFinite(a) || !Double.isFinite(b)) {
// NB: 0 < x < 1 :
......
/*
* This material is distributed under the GNU General Public License
* Version 2. You may review the terms of this license at
* http://www.gnu.org/licenses/gpl-2.0.html
*
* Copyright (C) 1998 Ross Ihaka
* Copyright (c) 2000--2015, The R Core Team
* Copyright (c) 2005, The R Foundation
* Copyright (c) 2016, Oracle and/or its affiliates
*
* All rights reserved.
*/
package com.oracle.truffle.r.library.stats;
import com.oracle.truffle.r.library.stats.Chisq.QChisq;
import com.oracle.truffle.r.library.stats.DPQ.EarlyReturn;
import com.oracle.truffle.r.library.stats.StatsFunctions.Function3_2;
public final class Qf implements Function3_2 {
private final QBeta qbeta = new QBeta();
private final QChisq qchisq = new QChisq();
@Override
public double evaluate(double p, double df1, double df2, boolean lowerTail, boolean logP) {
if (Double.isNaN(p) || Double.isNaN(df1) || Double.isNaN(df2)) {
return p + df1 + df2;
}
if (df1 <= 0. || df2 <= 0.) {
return RMath.mlError();
}
try {
DPQ.rqp01boundaries(p, 0, Double.POSITIVE_INFINITY, lowerTail, logP);
} catch (EarlyReturn e) {
return e.result;
}
/*
* fudge the extreme DF cases -- qbeta doesn't do this well. But we still need to fudge the
* infinite ones.
*/
if (df1 <= df2 && df2 > 4e5) {
if (!Double.isFinite(df1)) { /* df1 == df2 == Inf : */
return 1.;
} else {
return qchisq.evaluate(p, df1, lowerTail, logP) / df1;
}
}
if (df1 > 4e5) { /* and so df2 < df1 */
return df2 / qchisq.evaluate(p, df2, !lowerTail, logP);
}
// FIXME: (1/qb - 1) = (1 - qb)/qb; if we know qb ~= 1, should use other tail
p = (1. / qbeta.evaluate(p, df2 / 2, df1 / 2, !lowerTail, logP) - 1.) * (df2 / df1);
return RMath.mlValid(p) ? p : Double.NaN;
}
}
/*
* This material is distributed under the GNU General Public License
* Version 2. You may review the terms of this license at
* http://www.gnu.org/licenses/gpl-2.0.html
*
* Copyright (C) 1998 Ross Ihaka
* Copyright (c) 1998--2008, The R Core Team
* Copyright (c) 2016, 2016, Oracle and/or its affiliates
*
* All rights reserved.
*/
package com.oracle.truffle.r.library.stats;
import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction1_Double;
import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider;
public final class RChisq extends RandFunction1_Double {
public static double rchisq(double df, RandomNumberProvider rand) {
if (!Double.isFinite(df) || df < 0.0) {
return RMath.mlError();
}
return new RGamma().execute(df / 2.0, 2.0, rand);
}
@Override
public double execute(double a, RandomNumberProvider rand) {
return rchisq(a, rand);
}
}
......@@ -16,6 +16,7 @@ import static com.oracle.truffle.r.library.stats.LBeta.lbeta;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RRuntime;
/**
......@@ -26,13 +27,36 @@ import com.oracle.truffle.r.runtime.RRuntime;
*/
public class RMath {
public enum MLError {
DOMAIN,
RANGE,
NOCONV,
PRECISION,
UNDERFLOW
}
/**
* corresponds to macro {@code ML_ERR_return_NAN} in GnuR.
* Corresponds to macro {@code ML_ERR_return_NAN} in GnuR.
*/
public static double mlError() {
return mlError(MLError.DOMAIN, "");
}
/**
* Corresponds to macro {@code ML_ERR} in GnuR. TODO: raise corresponding warning
*/
public static double mlError(@SuppressWarnings("unused") MLError error, @SuppressWarnings("unused") String message) {
return Double.NaN;
}
public static void mlWarning(RError.Message message, Object... args) {
RError.warning(null, message, args);
}
public static boolean mlValid(double d) {
return !Double.isNaN(d);
}
public static double lfastchoose(double n, double k) {
return -Math.log(n + 1.) - lbeta(n - k + 1., k + 1.);
}
......@@ -56,7 +80,7 @@ public class RMath {
return ((y >= 0) ? TOMS708.fabs(x) : -TOMS708.fabs(x));
}
public static double fmod(double a, double b) {
private static double fmod(double a, double b) {
double q = a / b;
if (b != 0) {
double tmp = a - Math.floor(q) * b;
......
......@@ -30,7 +30,7 @@ public final class RNchisq extends RandFunction2_Double {
} else {
double r = RPois.rpois(lambda / 2., rand);
if (r > 0.) {
r = RChisq.rchisq(2. * r, rand);
r = Chisq.RChisq.rchisq(2. * r, rand);
}
if (df > 0.) {
r += rgamma.execute(df / 2., 2., rand);
......
......@@ -23,8 +23,8 @@ public final class Rf extends RandFunction2_Double {
double v1;
double v2;
v1 = Double.isFinite(n1) ? (RChisq.rchisq(n1, rand) / n1) : 1;
v2 = Double.isFinite(n2) ? (RChisq.rchisq(n2, rand) / n2) : 1;
v1 = Double.isFinite(n1) ? (Chisq.RChisq.rchisq(n1, rand) / n1) : 1;
v2 = Double.isFinite(n2) ? (Chisq.RChisq.rchisq(n2, rand) / n2) : 1;
return v1 / v2;
}
}
......@@ -11,7 +11,7 @@
*/
package com.oracle.truffle.r.library.stats;
import static com.oracle.truffle.r.library.stats.RChisq.rchisq;
import static com.oracle.truffle.r.library.stats.Chisq.RChisq.rchisq;
import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction1_Double;
import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandomNumberProvider;
......
......@@ -46,6 +46,7 @@ import com.oracle.truffle.r.library.stats.Cauchy.PCauchy;
import com.oracle.truffle.r.library.stats.Cauchy.RCauchy;
import com.oracle.truffle.r.library.stats.CdistNodeGen;
import com.oracle.truffle.r.library.stats.Chisq;
import com.oracle.truffle.r.library.stats.Chisq.RChisq;
import com.oracle.truffle.r.library.stats.CompleteCases;
import com.oracle.truffle.r.library.stats.CovcorNodeGen;
import com.oracle.truffle.r.library.stats.CutreeNodeGen;
......@@ -77,12 +78,13 @@ import com.oracle.truffle.r.library.stats.Pbinom;
import com.oracle.truffle.r.library.stats.Pf;
import com.oracle.truffle.r.library.stats.Pnorm;
import com.oracle.truffle.r.library.stats.Pt;
import com.oracle.truffle.r.library.stats.QBeta;
import com.oracle.truffle.r.library.stats.QPois;
import com.oracle.truffle.r.library.stats.Qbinom;
import com.oracle.truffle.r.library.stats.Qf;
import com.oracle.truffle.r.library.stats.Qnorm;
import com.oracle.truffle.r.library.stats.Qt;
import com.oracle.truffle.r.library.stats.RBeta;
import com.oracle.truffle.r.library.stats.RChisq;
import com.oracle.truffle.r.library.stats.RGamma;
import com.oracle.truffle.r.library.stats.RHyper;
import com.oracle.truffle.r.library.stats.RMultinomNodeGen;
......@@ -151,7 +153,7 @@ import com.oracle.truffle.r.runtime.ffi.RFFIFactory;
public class CallAndExternalFunctions {
@TruffleBoundary
protected static Object encodeArgumentPairList(RArgsValuesAndNames args, String symbolName) {
private static Object encodeArgumentPairList(RArgsValuesAndNames args, String symbolName) {
Object list = RNull.instance;
for (int i = args.getLength() - 1; i >= 0; i--) {
String name = args.getSignature().getName(i);
......@@ -161,8 +163,8 @@ public class CallAndExternalFunctions {
return list;
}
protected abstract static class CallRFFIAdapter extends LookupAdapter {
@Child protected CallRFFI.CallRFFINode callRFFINode = RFFIFactory.getRFFI().getCallRFFI().createCallRFFINode();
abstract static class CallRFFIAdapter extends LookupAdapter {
@Child CallRFFI.CallRFFINode callRFFINode = RFFIFactory.getRFFI().getCallRFFI().createCallRFFINode();
}
/**
......@@ -321,6 +323,8 @@ public class CallAndExternalFunctions {
return StatsFunctionsFactory.Function3_2NodeGen.create(new Pbinom());
case "pbeta":
return StatsFunctionsFactory.Function3_2NodeGen.create(new Pbeta());
case "qbeta":
return StatsFunctionsFactory.Function3_2NodeGen.create(new QBeta());
case "dcauchy":
return StatsFunctionsFactory.Function3_1NodeGen.create(new DCauchy());
case "pcauchy":
......@@ -329,12 +333,16 @@ public class CallAndExternalFunctions {
return StatsFunctionsFactory.Function3_2NodeGen.create(new Cauchy.QCauchy());
case "pf":
return StatsFunctionsFactory.Function3_2NodeGen.create(new Pf());
case "qf":
return StatsFunctionsFactory.Function3_2NodeGen.create(new Qf());
case "df":
return StatsFunctionsFactory.Function3_1NodeGen.create(new Df());
case "dgamma":
return StatsFunctionsFactory.Function3_1NodeGen.create(new DGamma());
case "dchisq":
return StatsFunctionsFactory.Function2_1NodeGen.create(new Chisq.DChisq());
case "qchisq":
return StatsFunctionsFactory.Function2_2NodeGen.create(new Chisq.QChisq());
case "qgeom":
return StatsFunctionsFactory.Function2_2NodeGen.create(new Geom.QGeom());
case "pchisq":
......
......@@ -527,6 +527,7 @@ public final class RError extends RuntimeException {
NA_IN_PROB_VECTOR("NA in probability vector"),
NEGATIVE_PROBABILITY("negative probability"),
NO_POSITIVE_PROBABILITIES("no positive probabilities"),
QBETA_ACURACY_WARNING("qbeta(a, *) =: x0 with |pbeta(x0,*%s) - alpha| = %.5g is not accurate"),
NON_POSITIVE_FILL("non-positive 'fill' argument will be ignored"),
MUST_BE_ONE_BYTE("invalid %s: must be one byte"),
INVALID_DECIMAL_SEP("invalid decimal separator"),
......
/*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package com.oracle.truffle.r.test.library.stats;
import static com.oracle.truffle.r.test.library.stats.TestStatFunctions.PROBABILITIES;
import org.junit.Test;
import com.oracle.truffle.r.test.TestBase;
/**
* Additional tests on the top of {@link TestStatFunctions}.
*/
public class TestExternal_qbeta extends TestBase {
private static final String[] BOOL_VALUES = new String[]{"T", "F"};
@Test
public void testQBeta() {
// check if in qbeta_raw with comment "p==0, q==0, p = Inf, q = Inf <==> treat as one- or
// two-point mass"
assertEval(template("qbeta(0.7, 0, 1/0, lower.tail=%0, log.p=F)", BOOL_VALUES));
assertEval(template("qbeta(log(0.7), 0, 1/0, lower.tail=%0, log.p=T)", BOOL_VALUES));
assertEval(template("qbeta(0.7, 1/0, 0, lower.tail=%0, log.p=F)", BOOL_VALUES));
assertEval(template("qbeta(log(0.7), 1/0, 0, lower.tail=%0, log.p=T)", BOOL_VALUES));
assertEval(template("qbeta(%0, 0, 0, lower.tail=%1, log.p=F)", new String[]{"0.1", "0.5", "0.7"}, BOOL_VALUES));
assertEval(template("qbeta(log(%0), 0, 0, lower.tail=%1, log.p=T)", new String[]{"0.1", "0.5", "0.7"}, BOOL_VALUES));
// checks swap_tail = (p_ > 0.5) where for lower.tail = log.p = TRUE is p_ = exp(p)
// exp(0.1) = 1.10....
assertEval(template("qbeta(log(%0), 0.1, 3, lower.tail=T, log.p=T)", PROBABILITIES));
// exp(-1) = 0.36...
assertEval(Output.MayIgnoreWarningContext, template("qbeta(log(%0), -1, 3, lower.tail=T, log.p=T)", PROBABILITIES));
}
}
......@@ -30,6 +30,8 @@ import com.oracle.truffle.r.test.TestBase;
* Common tests for functions implemented using {@code StatsFunctions} infrastructure.
*/
public class TestStatFunctions extends TestBase {
public static final String[] PROBABILITIES = new String[]{"0", "42e-80", "0.1", "0.5", "0.7", "1-42e-80", "1"};
private static final String[] FUNCTION3_1_NAMES = {"dgamma", "dbeta", "dcauchy", "dlnorm", "dlogis", "dunif"};
private static final String[] FUNCTION3_1_PARAMS = {
"10, 10, 10, log=TRUE",
......@@ -61,7 +63,7 @@ public class TestStatFunctions extends TestBase {
assertEval(Output.IgnoreWhitespace, template("set.seed(1); %0(%1)", FUNCTION2_1_NAMES, FUNCTION2_1_PARAMS));
}
private static final String[] FUNCTION2_2_NAMES = {"pchisq", "pexp", "qexp", "qgeom", "pgeom", "qt", "pt", "qpois", "ppois"};
private static final String[] FUNCTION2_2_NAMES = {"pchisq", "pexp", "qexp", "qgeom", "pgeom", "qt", "pt", "qpois", "ppois", "qchisq"};
private static final String[] FUNCTION2_2_PARAMS = {
"0, 10",
"c(-1, 0, 0.2, 2), rep(c(-1, 0, 0.1, 0.9, 3), 4)",
......@@ -80,7 +82,8 @@ public class TestStatFunctions extends TestBase {
assertEval(Output.IgnoreWhitespace, template("set.seed(1); %0(%1)", FUNCTION2_2_NAMES, new String[]{"rep(c(1, 0, 0.1), 5), c(NA, 0, NaN, 1/0, -1/0)"}));
}
private static final String[] FUNCTION3_2_NAMES = {"pbeta", "pcauchy", "qcauchy", "qlnorm", "plnorm", "qbinom", "pnorm", "qnorm", "qlogis", "pf", "pbinom", "plogis", "punif", "qunif"};
private static final String[] FUNCTION3_2_NAMES = {"pbeta", "pcauchy", "qcauchy", "qlnorm", "plnorm", "qbinom", "pnorm", "qnorm", "qlogis", "pf", "pbinom", "plogis", "punif", "qunif", "qbeta",
"qf"};
private static final String[] FUNCTION3_2_PARAMS = {
"0, 10, 10",
"c(-1, 0, 0.2, 2), c(-1, 0, 0.1, 0.9, 3), rep(c(-1, 0, 1, 0.1, -0.1, 0.0001), 20)",
......
......@@ -52,6 +52,10 @@ com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/LBeta.java,g
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Pbinom.java,gnu_r_ihaka.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Pf.java,gnu_r_ihaka_core.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Pnorm.java,gnu_r_ihaka.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/PPois.java,gnu_r_ihaka_core.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/QPois.java,gnu_r_ihaka_core.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/QBeta.java,gnu_r_scan.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Qf.java,gnu_r_ihaka.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Qbinom.java,gnu_r_ihaka.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Qnorm.java,gnu_r_ihaka.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Random2.java,gnu_r.copyright
......@@ -73,7 +77,6 @@ com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RGamma.java,
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RNbinomMu.java,gnu_r_ihaka_core.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Logis.java,gnu_r_ihaka_core.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rf.java,gnu_r_ihaka_core.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RChisq.java,gnu_r_ihaka_core.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Exp.java,gnu_r_ihaka_core.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Geom.java,gnu_r_ihaka_core.copyright
com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Dt.java,gnu_r.core.copyright
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment