diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandGenerationFunctions.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandGenerationFunctions.java new file mode 100644 index 0000000000000000000000000000000000000000..9fcb9b77fdd42c164029cf11cc3c730f434bfcac --- /dev/null +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandGenerationFunctions.java @@ -0,0 +1,194 @@ +/* + * This material is distributed under the GNU General Public License + * Version 2. You may review the terms of this license at + * http://www.gnu.org/licenses/gpl-2.0.html + * + * Copyright (c) 1995, 1996, 1997 Robert Gentleman and Ross Ihaka + * Copyright (c) 1998-2013, The R Core Team + * Copyright (c) 2003-2015, The R Foundation + * Copyright (c) 2016, 2016, Oracle and/or its affiliates + * + * All rights reserved. + */ + +package com.oracle.truffle.r.library.stats; + +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.abstractVectorValue; +import static com.oracle.truffle.r.runtime.RError.SHOW_CALLER; +import static com.oracle.truffle.r.runtime.RError.Message.INVALID_UNNAMED_ARGUMENTS; + +import java.util.Arrays; + +import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.nodes.Node; +import com.oracle.truffle.api.profiles.BranchProfile; +import com.oracle.truffle.r.library.stats.RandGenerationFunctionsFactory.ConvertToLengthNodeGen; +import com.oracle.truffle.r.nodes.builtin.CastBuilder; +import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode; +import com.oracle.truffle.r.nodes.unary.CastIntegerNode; +import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.RRuntime; +import com.oracle.truffle.r.runtime.data.RDataFactory; +import com.oracle.truffle.r.runtime.data.RDouble; +import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.nodes.RNode; +import com.oracle.truffle.r.runtime.ops.na.NACheck; + +public final class RandGenerationFunctions { + private static final RDouble DUMMY_VECTOR = RDouble.valueOf(1); + + private RandGenerationFunctions() { + // static class + } + + // inspired by the DEFRAND{X}_REAL and DEFRAND{X}_INT macros in GnuR + + public interface RandFunction3_Int { + int evaluate(double a, double b, double c); + } + + public interface RandFunction2_Int extends RandFunction3_Int { + @Override + default int evaluate(double a, double b, double c) { + return evaluate(a, b); + } + + int evaluate(double a, double b); + } + + static final class RandGenerationProfiles { + final BranchProfile nanResult = BranchProfile.create(); + final BranchProfile nan = BranchProfile.create(); + final NACheck aCheck = NACheck.create(); + final NACheck bCheck = NACheck.create(); + final NACheck cCheck = NACheck.create(); + + public static RandGenerationProfiles create() { + return new RandGenerationProfiles(); + } + } + + private static RAbstractIntVector evaluate3Int(Node node, RandFunction3_Int function, int length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, + RandGenerationProfiles profiles) { + int aLength = a.getLength(); + int bLength = b.getLength(); + int cLength = c.getLength(); + if (aLength == 0 || bLength == 0 || cLength == 0) { + profiles.nanResult.enter(); + RError.warning(SHOW_CALLER, RError.Message.NAN_PRODUCED); + int[] nansResult = new int[length]; + Arrays.fill(nansResult, RRuntime.INT_NA); + return RDataFactory.createIntVector(nansResult, false); + } + + RNode.reportWork(node, length); + boolean complete = true; + boolean nans = false; + profiles.aCheck.enable(a); + profiles.bCheck.enable(b); + profiles.cCheck.enable(c); + int[] result = new int[length]; + for (int i = 0; i < length; i++) { + double aValue = a.getDataAt(i % aLength); + double bValue = b.getDataAt(i % bLength); + double cValue = c.getDataAt(i % cLength); + int value; + if (Double.isNaN(aValue) || Double.isNaN(bValue) || Double.isNaN(cValue)) { + profiles.nan.enter(); + value = RRuntime.INT_NA; + if (profiles.aCheck.check(aValue) || profiles.bCheck.check(bValue) || profiles.cCheck.check(cValue)) { + complete = false; + } + } else { + value = function.evaluate(aValue, bValue, cValue); + if (Double.isNaN(value)) { + profiles.nan.enter(); + nans = true; + } + } + result[i] = value; + } + if (nans) { + RError.warning(SHOW_CALLER, RError.Message.NAN_PRODUCED); + } + return RDataFactory.createIntVector(result, complete); + } + + /** + * Converts given value to actual length that should be used as length of the output vector. The + * argument must be cast using {@link #addLengthCast(CastBuilder)}. Using this node allows us to + * avoid casting of long vectors to integers, if we only need to know their length. + */ + protected abstract static class ConvertToLength extends Node { + public abstract int execute(RAbstractVector value); + + @Specialization(guards = "vector.getLength() == 1") + public int lengthOne(RAbstractVector vector, + @Cached("createNonPreserving()") CastIntegerNode castNode, + @Cached("create()") BranchProfile seenNA) { + int result = ((RAbstractIntVector) castNode.execute(vector)).getDataAt(0); + if (RRuntime.isNA(result)) { + seenNA.enter(); + throw RError.error(SHOW_CALLER, INVALID_UNNAMED_ARGUMENTS); + } + return result; + } + + @Specialization(guards = "vector.getLength() != 1") + public int notSingle(RAbstractVector vector) { + return vector.getLength(); + } + + private static void addLengthCast(CastBuilder casts) { + casts.arg(0).defaultError(SHOW_CALLER, INVALID_UNNAMED_ARGUMENTS).mustBe(abstractVectorValue()).asVector(); + } + } + + public abstract static class Function3_IntNode extends RExternalBuiltinNode.Arg4 { + private final RandFunction3_Int function; + @Child private ConvertToLength convertToLength = ConvertToLengthNodeGen.create(); + + protected Function3_IntNode(RandFunction3_Int function) { + this.function = function; + } + + @Override + protected void createCasts(CastBuilder casts) { + ConvertToLength.addLengthCast(casts); + casts.arg(1).asDoubleVector(); + casts.arg(2).asDoubleVector(); + casts.arg(3).asDoubleVector(); + } + + @Specialization + protected RAbstractIntVector evaluate(RAbstractVector length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, + @Cached("create()") RandGenerationProfiles profiles) { + return evaluate3Int(this, function, convertToLength.execute(length), a, b, c, profiles); + } + } + + public abstract static class Function2_IntNode extends RExternalBuiltinNode.Arg3 { + private final RandFunction2_Int function; + @Child private ConvertToLength convertToLength = ConvertToLengthNodeGen.create(); + + protected Function2_IntNode(RandFunction2_Int function) { + this.function = function; + } + + @Override + protected void createCasts(CastBuilder casts) { + ConvertToLength.addLengthCast(casts); + casts.arg(1).asDoubleVector(); + casts.arg(2).asDoubleVector(); + } + + @Specialization + protected RAbstractIntVector evaluate(RAbstractVector length, RAbstractDoubleVector a, RAbstractDoubleVector b, + @Cached("create()") RandGenerationProfiles profiles) { + return evaluate3Int(this, function, convertToLength.execute(length), a, b, DUMMY_VECTOR, profiles); + } + } +} diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rbinom.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rbinom.java index 74cf21b87fcb1104e4b3d67104eec9f6400a142c..e814c90a222325d9ba8f6e91524cbbdd602b2f08 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rbinom.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rbinom.java @@ -13,62 +13,50 @@ package com.oracle.truffle.r.library.stats; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; -import com.oracle.truffle.api.dsl.Cached; -import com.oracle.truffle.api.dsl.Specialization; -import com.oracle.truffle.api.profiles.BranchProfile; -import com.oracle.truffle.r.nodes.builtin.CastBuilder; -import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode; -import com.oracle.truffle.r.nodes.profile.VectorLengthProfile; -import com.oracle.truffle.r.runtime.RError; -import com.oracle.truffle.r.runtime.RError.Message; -import com.oracle.truffle.r.runtime.data.RDataFactory; -import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; -import com.oracle.truffle.r.runtime.nodes.RNode; -import com.oracle.truffle.r.runtime.ops.na.NAProfile; +import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_Int; +import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.rng.RRNG; // transcribed from rbinom.c -public abstract class Rbinom extends RExternalBuiltinNode.Arg3 { +public final class Rbinom implements RandFunction2_Int { + + private final Qbinom qbinom = new Qbinom(); @TruffleBoundary private static double unifRand() { return RRNG.unifRand(); } - private final Qbinom qbinom = new Qbinom(); - - double rbinom(double nin, double pp, BranchProfile nanProfile) { + @Override + public int evaluate(double nin, double pp) { double psave = -1.0; int nsave = -1; if (!Double.isFinite(nin)) { - nanProfile.enter(); - return Double.NaN; + return RRuntime.INT_NA; } double r = MathConstants.forceint(nin); if (r != nin) { - nanProfile.enter(); - return Double.NaN; + return RRuntime.INT_NA; } /* n=0, p=0, p=1 are not errors <TSL> */ if (!Double.isFinite(pp) || r < 0 || pp < 0. || pp > 1.) { - nanProfile.enter(); - return Double.NaN; + return RRuntime.INT_NA; } if (r == 0 || pp == 0.) { return 0; } if (pp == 1.) { - return r; + return (int) r; } if (r >= Integer.MAX_VALUE) { /* * evade integer overflow, and r == INT_MAX gave only even values */ - return qbinom.evaluate(unifRand(), r, pp, /* lower_tail */false, /* log_p */false); + return (int) qbinom.evaluate(unifRand(), r, pp, /* lower_tail */false, /* log_p */false); } /* else */ int n = (int) r; @@ -258,47 +246,4 @@ public abstract class Rbinom extends RExternalBuiltinNode.Arg3 { } return ix; } - - @Override - protected void createCasts(CastBuilder casts) { - casts.arg(0).asDoubleVector(); - casts.arg(1).asDoubleVector(); - casts.arg(2).asDoubleVector(); - } - - @Specialization - protected Object rbinom(RAbstractDoubleVector n, RAbstractDoubleVector size, RAbstractDoubleVector prob, // - @Cached("create()") NAProfile na, // - @Cached("create()") BranchProfile nanProfile, // - @Cached("create()") VectorLengthProfile sizeProfile, // - @Cached("create()") VectorLengthProfile probProfile) { - int length = n.getLength(); - RNode.reportWork(this, length); - if (length == 1) { - double l = n.getDataAt(0); - if (Double.isNaN(l) || l < 0 || l > Integer.MAX_VALUE) { - throw RError.error(RError.SHOW_CALLER, Message.INVALID_UNNAMED_ARGUMENTS); - } - length = (int) l; - } - int sizeLength = sizeProfile.profile(size.getLength()); - int probLength = probProfile.profile(prob.getLength()); - - double[] result = new double[length]; - boolean complete = true; - boolean nans = false; - for (int i = 0; i < length; i++) { - double value = rbinom(size.getDataAt(i % sizeLength), prob.getDataAt(i % probLength), nanProfile); - if (na.isNA(value)) { - complete = false; - } else if (Double.isNaN(value)) { - nans = true; - } - result[i] = value; - } - if (nans) { - RError.warning(RError.SHOW_CALLER, RError.Message.NAN_PRODUCED); - } - return RDataFactory.createDoubleVector(result, complete); - } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/ForeignFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/ForeignFunctions.java index 00b577f8bc9dc53b4737f7d39fc6152c1826d916..646e6fabcc6e88a0c9ae53baec87e926cb46676d 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/ForeignFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/ForeignFunctions.java @@ -51,7 +51,8 @@ import com.oracle.truffle.r.library.stats.Pf; import com.oracle.truffle.r.library.stats.Pnorm; import com.oracle.truffle.r.library.stats.Qbinom; import com.oracle.truffle.r.library.stats.Qnorm; -import com.oracle.truffle.r.library.stats.RbinomNodeGen; +import com.oracle.truffle.r.library.stats.RandGenerationFunctionsFactory; +import com.oracle.truffle.r.library.stats.Rbinom; import com.oracle.truffle.r.library.stats.RnormNodeGen; import com.oracle.truffle.r.library.stats.RunifNodeGen; import com.oracle.truffle.r.library.stats.SplineFunctionsFactory.SplineCoefNodeGen; @@ -368,7 +369,7 @@ public class ForeignFunctions { case "qbinom": return StatsFunctionsFactory.Function3_2NodeGen.create(new Qbinom()); case "rbinom": - return RbinomNodeGen.create(); + return RandGenerationFunctionsFactory.Function2_IntNodeGen.create(new Rbinom()); case "pbinom": return StatsFunctionsFactory.Function3_2NodeGen.create(new Pbinom()); case "pf": diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestExternal_rbinom.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestExternal_rbinom.java new file mode 100644 index 0000000000000000000000000000000000000000..79296c9ac3e58bd59a3d54e33bd25095acc260cc --- /dev/null +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestExternal_rbinom.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package com.oracle.truffle.r.test.library.stats; + +import org.junit.Test; + +import com.oracle.truffle.r.test.TestBase; + +public class TestExternal_rbinom extends TestBase { + @Test + public void testRbinom() { + assertEval("set.seed(42); rbinom(10, 10, 0.5)"); + assertEval("set.seed(42); rbinom('10', 10, 0.5)"); + assertEval(Output.IgnoreWarningContext, "set.seed(42); rbinom('aa', 10, 0.5)"); + assertEval("set.seed(42); rbinom(10, 2:10, c(0.1, 0.5, 0.9))"); + assertEval("set.seed(42); rbinom(1:10, 2:10, c(0.1, 0.5, 0.9))"); + assertEval("set.seed(42); rbinom(c(1,2), 11:12, c(0.1, 0.5, 0.9))"); + } +} diff --git a/mx.fastr/copyrights/overrides b/mx.fastr/copyrights/overrides index 204b1889bef3b1be2f1fd565c36f48e20bcef9f5..b2774bbe81df2fcdb8776e053ecbdb3aec5e15f4 100644 --- a/mx.fastr/copyrights/overrides +++ b/mx.fastr/copyrights/overrides @@ -47,6 +47,7 @@ com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rbinom.java, com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Rnorm.java,gnu_r.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/SplineFunctions.java,gnu_r_splines.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsFunctions.java,gnu_r_gentleman_ihaka.copyright +com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandGenerationFunctions.java,gnu_r_gentleman_ihaka.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsUtil.java,gnu_r_ihaka.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/TOMS708.java,gnu_r.copyright com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/tools/DirChmod.java,gnu_r.copyright