From f6e24e87cf156c58bffb22b6f5227ed535e0a1ea Mon Sep 17 00:00:00 2001 From: stepan <stepan.sindelar@oracle.com> Date: Wed, 26 Oct 2016 15:23:54 +0200 Subject: [PATCH] Sample2 builtin implemented --- .../r/nodes/builtin/base/BasePackage.java | 1 + .../truffle/r/nodes/builtin/base/Sample2.java | 115 ++++++++++++++++++ .../com/oracle/truffle/r/runtime/RError.java | 1 + .../truffle/r/test/ExpectedTestOutput.test | 41 +++++++ .../r/test/builtins/TestBuiltin_sample2.java | 50 ++++++++ mx.fastr/copyrights/overrides | 1 + 6 files changed, 209 insertions(+) create mode 100644 com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sample2.java create mode 100644 com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_sample2.java diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java index 6dbe398cbd..7af620d538 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java @@ -559,6 +559,7 @@ public class BasePackage extends RBuiltinPackage { add(S3DispatchFunctions.NextMethod.class, S3DispatchFunctionsFactory.NextMethodNodeGen::create); add(S3DispatchFunctions.UseMethod.class, S3DispatchFunctionsFactory.UseMethodNodeGen::create); add(Sample.class, SampleNodeGen::create); + add(Sample2.class, Sample2NodeGen::create); add(Scan.class, ScanNodeGen::create); add(Seq.class, SeqNodeGen::create); add(SeqAlong.class, SeqAlongNodeGen::create); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sample2.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sample2.java new file mode 100644 index 0000000000..5efd497c9c --- /dev/null +++ 
b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sample2.java @@ -0,0 +1,115 @@ +/* + * This material is distributed under the GNU General Public License + * Version 2. You may review the terms of this license at + * http://www.gnu.org/licenses/gpl-2.0.html + * + * Copyright (c) 1995, 1996 Robert Gentleman and Ross Ihaka + * Copyright (c) 1997-2015, The R Core Team + * Copyright (c) 2016, Oracle and/or its affiliates + * + * All rights reserved. + */ + +package com.oracle.truffle.r.nodes.builtin.base; + +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.doubleValue; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gte; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gte0; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.integerValue; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.isFinite; +import static com.oracle.truffle.r.runtime.RError.SHOW_CALLER; +import static com.oracle.truffle.r.runtime.RError.Message.ALGORITHM_FOR_SIZE_N_DIV_2; +import static com.oracle.truffle.r.runtime.RError.Message.INVALID_ARGUMENT; +import static com.oracle.truffle.r.runtime.RError.Message.INVALID_FIRST_ARGUMENT; +import static com.oracle.truffle.r.runtime.builtins.RBehavior.MODIFIES_STATE; +import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; + +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.dsl.TypeSystemReference; +import com.oracle.truffle.api.profiles.BranchProfile; +import com.oracle.truffle.r.nodes.EmptyTypeSystemFlatLayout; +import com.oracle.truffle.r.nodes.builtin.CastBuilder; +import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.Utils.NonRecursiveHashSetDouble; +import com.oracle.truffle.r.runtime.Utils.NonRecursiveHashSetInt; +import com.oracle.truffle.r.runtime.builtins.RBuiltin; +import 
com.oracle.truffle.r.runtime.data.RDataFactory; +import com.oracle.truffle.r.runtime.data.RDoubleVector; +import com.oracle.truffle.r.runtime.data.RIntVector; +import com.oracle.truffle.r.runtime.rng.RRNG; + +/** + * Sample2 is a more efficient special-case implementation of {@link Sample}. It rejects requests with size greater than x / 2 and returns an integer vector when x is at most Integer.MAX_VALUE, otherwise a double vector. + */ +@RBuiltin(name = "sample2", kind = INTERNAL, parameterNames = {"x", "size"}, behavior = MODIFIES_STATE) +@TypeSystemReference(EmptyTypeSystemFlatLayout.class) +public abstract class Sample2 extends RBuiltinNode { + /* 33554432 is 2^25; ru() uses it as the scale when combining two uniform draws. */ private static final double U = 33554432.0; + /* Integer.MAX_VALUE widened to double so the specialization guards can compare it against x. */ static final double MAX_INT = Integer.MAX_VALUE; + + /* Entered only on the error path in validate(). */ private final BranchProfile errorProfile = BranchProfile.create(); + + /** Declares the argument casts. x must be an integer or double value, must not be NA, is converted to a double vector and reduced to its first element, which must be at least 0 and finite; failures report INVALID_FIRST_ARGUMENT. size must be an integer or double value, is converted to an integer vector and reduced to its first element, which must not be NA and must be at least 0; failures report INVALID_ARGUMENT for size. NOTE(review): x also passes through allowNull(); how a NULL x behaves in the later pipeline steps is not visible here - confirm. */ @Override + protected void createCasts(CastBuilder casts) { + // @formatter:off + casts.arg("x").defaultError(SHOW_CALLER, INVALID_FIRST_ARGUMENT).allowNull(). + mustBe(integerValue().or(doubleValue())).notNA(SHOW_CALLER, INVALID_FIRST_ARGUMENT). + asDoubleVector().findFirst().mustBe(gte(0.0)).mustBe(isFinite()); + casts.arg("size").defaultError(SHOW_CALLER, INVALID_ARGUMENT, "size"). + mustBe(integerValue().or(doubleValue())). + asIntegerVector().findFirst(). + defaultError(SHOW_CALLER, INVALID_ARGUMENT, "size"). 
+ notNA().mustBe(gte0()); + // @formatter:on + } + + /** Specialization selected when x is greater than Integer.MAX_VALUE; returns a double vector of length size. After validate(), each slot gets up to 100 attempts: a candidate Math.floor(x * ru() + 1) is computed and stored as soon as used.add(candidate) returns false. A slot whose 100 attempts are all rejected keeps the array default 0.0. NOTE(review): presumably add() returns true when the value was already in the set, which would make the stored values distinct - confirm against Utils.NonRecursiveHashSetDouble. NOTE(review): RRNG.getRNGState() is called but no matching write-back call is visible in this method - confirm whether one is required. The true flag passed to createDoubleVector presumably marks the vector as complete - confirm. */ @Specialization(guards = "x > MAX_INT") + protected RDoubleVector doLargeX(double x, int size) { + validate(x, size); + RRNG.getRNGState(); + + double[] result = new double[size]; + NonRecursiveHashSetDouble used = new NonRecursiveHashSetDouble((int) (size * 1.2)); + for (int i = 0; i < size; i++) { + for (int j = 0; j < 100; j++) { + double value = Math.floor(x * ru() + 1); + if (!used.add(value)) { + result[i] = value; + break; + } + } + } + return RDataFactory.createDoubleVector(result, true); + } + + /** Specialization selected when x is at most Integer.MAX_VALUE; returns an integer vector of length size. Same retry scheme as doLargeX, but each candidate is (int) (x * RRNG.unifRand() + 1), using a single uniform draw instead of ru(). A slot whose 100 attempts are all rejected keeps the array default 0. NOTE(review): the same assumptions about add() and the missing RNG state write-back apply here - confirm. */ @Specialization(guards = "x <= MAX_INT") + protected RIntVector doSmallX(double x, int size) { + validate(x, size); + RRNG.getRNGState(); + + int[] result = new int[size]; + NonRecursiveHashSetInt used = new NonRecursiveHashSetInt((int) (size * 1.2)); + for (int i = 0; i < size; i++) { + for (int j = 0; j < 100; j++) { + int value = (int) (x * RRNG.unifRand() + 1); + if (!used.add(value)) { + result[i] = value; + break; + } + } + } + return RDataFactory.createIntVector(result, true); + } + + /** Raises an R error with the ALGORITHM_FOR_SIZE_N_DIV_2 message when size is greater than x / 2; otherwise does nothing. */ private void validate(double x, int size) { + if (size > x / 2) { + errorProfile.enter(); + throw RError.error(SHOW_CALLER, ALGORITHM_FOR_SIZE_N_DIV_2); + } + } + + /** Combines two RRNG.unifRand() draws into one value: floor(U * first draw) plus the second draw, divided by U. Used by doLargeX only; presumably this gives finer resolution than a single draw for large x - confirm. */ private double ru() { + return (Math.floor(U * RRNG.unifRand()) + RRNG.unifRand()) / U; + } +} diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java index e1399d451d..1d20574610 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java @@ -526,6 +526,7 @@ public final class RError extends RuntimeException { INVALID_TIES_FOR_RANK("invalid ties.method for rank() [should never happen]"), UNIMPLEMENTED_TYPE_IN_GREATER("unimplemented type '%s' in greater"), RANK_LARGE_N("parameter 'n' is greater than length(x), GnuR output is non-deterministic, FastR will 
use n=length(x)"), + ALGORITHM_FOR_SIZE_N_DIV_2("This algorithm is for size <= n/2"), // below: not exactly GNU-R message TOO_FEW_POSITIVE_PROBABILITY("too few positive probabilities"), DOTS_BOUNDS("The ... list does not contain %s elements"), diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index cf28974969..1b79171bc9 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -42986,6 +42986,47 @@ integer(0) #argv <- structure(list(x = c(0, 0)), .Names = 'x');do.call('sample', argv) [1] 0 0 +##com.oracle.truffle.r.test.builtins.TestBuiltin_sample2.testArgsCasts +#set.seed(42); .Internal(sample2(-2, 1)) +Error: invalid first argument + +##com.oracle.truffle.r.test.builtins.TestBuiltin_sample2.testArgsCasts +#set.seed(42); .Internal(sample2(-2L, 1)) +Error: invalid first argument + +##com.oracle.truffle.r.test.builtins.TestBuiltin_sample2.testArgsCasts +#set.seed(42); .Internal(sample2(10, -2)) +Error: invalid 'size' argument + +##com.oracle.truffle.r.test.builtins.TestBuiltin_sample2.testArgsCasts +#set.seed(42); .Internal(sample2(10, 2.99)) +[1] 10 3 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_sample2.testArgsCasts +#set.seed(42); .Internal(sample2(10, 8)) +Error: This algorithm is for size <= n/2 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_sample2.testArgsCasts +#set.seed(42); .Internal(sample2(NA, 1)) +Error: invalid first argument + +##com.oracle.truffle.r.test.builtins.TestBuiltin_sample2.testArgsCasts +#set.seed(42); .Internal(sample2(NaN, 1)) +Error: invalid first argument + +##com.oracle.truffle.r.test.builtins.TestBuiltin_sample2.testSample2 +#set.seed(42); .Internal(sample2(10, 2)) +[1] 10 3 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_sample2.testSample2 +#set.seed(42); 
.Internal(sample2(10L, 3L)) +[1] 10 3 9 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_sample2.testSample2 +#set.seed(42); .Internal(sample2(4147483647, 10)) + [1] 3794143200 1186759075 2661629063 3054987943 2724864852 1898476583 + [7] 3876537848 1917352029 4057178109 1970042664 + ##com.oracle.truffle.r.test.builtins.TestBuiltin_scale.testscale1 #argv <- structure(list(x = structure(c(0.0280387932434678, 0.789736648323014, 0.825624888762832, 0.102816025260836, 0.290661531267688, 0.0517604837659746, 0.610383243998513, 0.78207225818187, 0.136790128657594, 0.8915234063752, 0.0216042066458613, 0.408875584136695, 0.69190051057376, 0.595735886832699, 0.936268283519894, 0.592950375983492, 0.852736486820504, 0.610123937483877, 0.600582004291937, 0.38303488586098, 0.412859325064346, 0.388432375853881, 0.457582515198737, 0.701614629011601, 0.449137942166999, 0.533179924823344, 0.317685069283471, 0.800954289967194, 0.0273033923003823, 0.496913943905383, 0.903582146391273, 0.725298138801008, 0.616459952667356, 0.341360273305327, 0.0613401387818158, 0.7339238144923, 0.720672776456922, 0.214702291414142, 0.283225567312911, 0.515186718199402, 0.558621872216463, 0.770191126968712, 0.959201833466068, 0.80451478343457, 0.307586128590629, 0.902739278972149, 0.992322677979246, 0.167487781029195, 0.796250741928816, 0.549091263208538, 0.0876540709286928, 0.424049312015995, 0.573274190537632, 0.763274750672281, 0.405174027662724, 0.828049632022157, 0.128607030957937, 0.479592794785276, 0.631105397362262, 0.406053610146046, 0.661386628635228, 0.958720558788627, 0.576542558381334, 0.0483133427333087, 0.615997062064707, 0.341076754732057, 0.901286069769412, 0.521056747529656, 0.92834516079165, 0.228773980634287, 0.458389508537948, 0.987496873131022, 0.0315267851110548, 0.872887850506231, 0.59517983533442, 0.935472247190773, 0.145392092177644, 0.255368477664888, 0.322336541488767, 0.507066876627505, 0.0745627176947892, 0.0313172969035804, 0.499229126842692, 0.868204665370286, 
0.232835006900132, 0.422810809221119, 0.803322346881032, 0.00151223805733025, 0.175151102710515, 0.469289294909686), .Dim = c(10L, 9L))), .Names = 'x');do.call('scale', argv) [,1] [,2] [,3] [,4] [,5] [,6] diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_sample2.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_sample2.java new file mode 100644 index 0000000000..6e79ab8063 --- /dev/null +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_sample2.java @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. 
+ */ +package com.oracle.truffle.r.test.builtins; + +import org.junit.Test; + +import com.oracle.truffle.r.test.TestBase; + +// Checkstyle: stop line length check +/** Tests for the sample2 .Internal builtin. Every snippet is passed to assertEval and begins with set.seed(42). NOTE(review): presumably assertEval compares the output against the entries recorded in ExpectedTestOutput.test - confirm in TestBase. */ public class TestBuiltin_sample2 extends TestBase { + /** Valid calls: double arguments, integer arguments, and an x of 4147483647, which is greater than Integer.MAX_VALUE (2147483647). */ @Test + public void testSample2() { + assertEval("set.seed(42); .Internal(sample2(10, 2))"); + assertEval("set.seed(42); .Internal(sample2(10L, 3L))"); + // test with n > MAX_INT + assertEval("set.seed(42); .Internal(sample2(4147483647, 10))"); + } + + /** Argument handling: negative x (double and integer), NA and NaN x, a size greater than x / 2, a negative size, and a fractional size of 2.99. */ @Test + public void testArgsCasts() { + assertEval("set.seed(42); .Internal(sample2(-2, 1))"); + assertEval("set.seed(42); .Internal(sample2(-2L, 1))"); + assertEval("set.seed(42); .Internal(sample2(NA, 1))"); + assertEval("set.seed(42); .Internal(sample2(NaN, 1))"); + + assertEval("set.seed(42); .Internal(sample2(10, 8))"); + assertEval("set.seed(42); .Internal(sample2(10, -2))"); + assertEval("set.seed(42); .Internal(sample2(10, 2.99))"); + } +} diff --git a/mx.fastr/copyrights/overrides b/mx.fastr/copyrights/overrides index 3632ac9ec9..126f5fb24e 100644 --- a/mx.fastr/copyrights/overrides +++ b/mx.fastr/copyrights/overrides @@ -165,6 +165,7 @@ com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/R com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowsumFunctions.java,gnu_r_gentleman_ihaka2.copyright com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/S3DispatchFunctions.java,purdue.copyright com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sample.java,gnu_r_sample.copyright +com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sample2.java,gnu_r_gentleman_ihaka2.copyright com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Strtrim.java,gnu_r_gentleman_ihaka.copyright com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Scan.java,gnu_r_scan.copyright 
com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Slot.java,gnu_r_gentleman_ihaka.copyright -- GitLab