diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Sample.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Sample.java index 13b8010e51b2e105967db0804bb68824982d657a..08222239e02b707fa60c76eb6fcb6dbbf9e71e64 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Sample.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Sample.java @@ -3,33 +3,15 @@ * Version 2. You may review the terms of this license at * http://www.gnu.org/licenses/gpl-2.0.html * - * Copyright (c) 2012-2014, Purdue University + * Copyright (c) 1995, 1996, Robert Gentleman and Ross Ihaka + * Copyright (c) 1997--2012, The R Core Team + * Copyright (c) 2003--2008, The R Foundation + * Copyright (c) 2012-2013, Purdue University * Copyright (c) 2013, 2014, Oracle and/or its affiliates * * All rights reserved. */ -/* - * R : A Computer Language for Statistical Data Analysis - * Copyright (C) 1995, 1996 Robert Gentleman and Ross Ihaka - * Copyright (C) 1997--2012 The R Core Team - * Copyright (C) 2003--2008 The R Foundation - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program; if not, a copy is available at - * http://www.r-project.org/Licenses/ - */ - package com.oracle.truffle.r.nodes.builtin.base; import com.oracle.truffle.api.*; @@ -55,37 +37,39 @@ public abstract class Sample extends RBuiltinNode { @Specialization(order = 10, guards = "invalidFirstArgument") @SuppressWarnings("unused") - public RIntVector doSampleInvalidFirstArg(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { + public RIntVector doSampleInvalidFirstArg(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) { CompilerDirectives.transferToInterpreter(); throw RError.getInvalidFirstArgument(getEncapsulatingSourceSection()); } @Specialization(order = 20, guards = "invalidProb") @SuppressWarnings("unused") - public RIntVector doSampleInvalidProb(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { + public RIntVector doSampleInvalidProb(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) { CompilerDirectives.transferToInterpreter(); throw RError.getIncorrectNumProb(getEncapsulatingSourceSection()); } @Specialization(order = 30, guards = "largerPopulation") @SuppressWarnings("unused") - public RIntVector doSampleLargerPopulation(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { + public RIntVector doSampleLargerPopulation(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) { CompilerDirectives.transferToInterpreter(); throw RError.getLargerPopu(getEncapsulatingSourceSection()); } @Specialization(order = 40, guards = "invalidSizeArgument") @SuppressWarnings("unused") - public RIntVector doSampleInvalidSize(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { + public RIntVector doSampleInvalidSize(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) { CompilerDirectives.transferToInterpreter(); throw RError.getInvalidArgument(getEncapsulatingSourceSection(), RRuntime.toString(size)); } @Specialization(order = 1, guards = {"!invalidFirstArgument", "!invalidProb", "!largerPopulation", "!invalidSizeArgument", "withReplacement"}) - public RIntVector doSampleWithReplacement(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { + public RIntVector doSampleWithReplacement(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) { + // The following code is transcribed from GNU R src/main/random.c lines 493-501 in + // function do_sample. double[] probArray = prob.getDataCopy(); - fixupProbability(probArray, x, size, isReaptable); + fixupProbability(probArray, x, size, isRepeatable); int nc = 0; for (double aProb : probArray) { if (x * aProb > 0.1) { @@ -101,39 +85,41 @@ public abstract class Sample extends RBuiltinNode { } @Specialization(order = 2, guards = {"!invalidFirstArgument", "!invalidProb", "!largerPopulation", "!invalidSizeArgument", "!withReplacement"}) - public RIntVector doSampleNoReplacement(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { + public RIntVector doSampleNoReplacement(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) { double[] probArray = prob.getDataCopy(); - fixupProbability(probArray, x, size, isReaptable); + fixupProbability(probArray, x, size, isRepeatable); return RDataFactory.createIntVector(probSampleWithoutReplace(x, probArray, size), RDataFactory.COMPLETE_VECTOR); } @SuppressWarnings("unused") @Specialization(order = 50, guards = "invalidFirstArgumentNullProb") - public RIntVector doSampleInvalidFirstArgument(final int x, final int size, final byte isReaptable, final RNull prob) { + public RIntVector doSampleInvalidFirstArgument(final int x, final int size, final byte isRepeatable, final RNull prob) { CompilerDirectives.transferToInterpreter(); throw RError.getInvalidFirstArgument(getEncapsulatingSourceSection()); } @SuppressWarnings("unused") @Specialization(order = 60, guards = "invalidSizeArgument") - public RIntVector doSampleInvalidSizeArgument(final int x, final int size, final byte isReaptable, final RNull prob) { + public RIntVector doSampleInvalidSizeArgument(final int x, final int size, final byte isRepeatable, final RNull prob) { CompilerDirectives.transferToInterpreter(); throw RError.getInvalidArgument(getEncapsulatingSourceSection(), RRuntime.toString(size)); } @SuppressWarnings("unused") @Specialization(order = 70, guards = "largerPopulation") - public RIntVector doSampleInvalidLargerPopulation(final int x, final int size, final byte isReaptable, final RNull prob) { + public RIntVector doSampleInvalidLargerPopulation(final int x, final int size, final byte isRepeatable, final RNull prob) { CompilerDirectives.transferToInterpreter(); throw RError.getIncorrectNumProb(getEncapsulatingSourceSection()); } @Specialization(order = 80, guards = {"!invalidFirstArgumentNullProb", "!invalidSizeArgument", "!largerPopulation"}) - public RIntVector doSample(final int x, final int size, final byte isReaptable, @SuppressWarnings("unused") final RNull prob) { + public RIntVector doSample(final int x, final int size, final byte isRepeatable, @SuppressWarnings("unused") final RNull prob) { // TODO:Add support of long integers. + // The following code is transcribed from GNU R src/main/random.c lines 533-545 in + // function do_sample. int[] result = new int[size]; /* avoid allocation for a single sample */ - if (isReaptable == RRuntime.LOGICAL_TRUE || size < 2) { + if (isRepeatable == RRuntime.LOGICAL_TRUE || size < 2) { for (int i = 0; i < size; i++) { result[i] = (int) (x * RRNG.unifRand() + 1); } @@ -154,12 +140,13 @@ public abstract class Sample extends RBuiltinNode { @SuppressWarnings("unused") @Specialization(order = 100, guards = "invalidIsRepeatable") - public RIntVector doSampleInvalidIsRepeatable(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { + public RIntVector doSampleInvalidIsRepeatable(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) { CompilerDirectives.transferToInterpreter(); - throw RError.getInvalidArgument(getEncapsulatingSourceSection(), RRuntime.toString(isReaptable)); + throw RError.getInvalidArgument(getEncapsulatingSourceSection(), RRuntime.toString(isRepeatable)); } private void fixupProbability(double[] probArray, int x, int size, byte isRepeatable) { + // The following code is transcribed from GNU R src/main/random.c lines 429-449 int nonZeroProbCount = 0; double probSum = 0; for (double aProb : probArray) { @@ -183,51 +170,52 @@ public abstract class Sample extends RBuiltinNode { } @SuppressWarnings("unused") - protected static boolean invalidFirstArgumentNullProb(final int x, final int size, final byte isReaptable, final RNull prob) { + protected static boolean invalidFirstArgumentNullProb(final int x, final int size, final byte isRepeatable, final RNull prob) { return !RRuntime.isFinite(x) || x < 0 || x > 4.5e15 || (size > 0 && x == 0); } @SuppressWarnings("unused") - protected static boolean invalidSizeArgument(final int x, final int size, final byte isReaptable, final RNull prob) { + protected static boolean invalidSizeArgument(final int x, final int size, final byte isRepeatable, final RNull prob) { return RRuntime.isNA(size) || size < 0; } @SuppressWarnings("unused") - protected static boolean largerPopulation(final int x, final int size, final byte isReaptable, final RNull prob) { - return isReaptable == RRuntime.LOGICAL_FALSE && size > x; + protected static boolean largerPopulation(final int x, final int size, final byte isRepeatable, final RNull prob) { + return isRepeatable == RRuntime.LOGICAL_FALSE && size > x; } @SuppressWarnings("unused") - protected static boolean invalidFirstArgument(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { + protected static boolean invalidFirstArgument(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) { return RRuntime.isNA(x) || x < 0 || (size > 0 && x == 0); } @SuppressWarnings("unused") - protected static boolean invalidSizeArgument(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { + protected static boolean invalidSizeArgument(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) { return RRuntime.isNA(size) || size < 0; } @SuppressWarnings("unused") - protected static boolean invalidProb(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { + protected static boolean invalidProb(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) { return prob.getLength() != x; } @SuppressWarnings("unused") - protected static boolean withReplacement(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { - return isReaptable == RRuntime.LOGICAL_TRUE; + protected static boolean withReplacement(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) { + return isRepeatable == RRuntime.LOGICAL_TRUE; } @SuppressWarnings("unused") - protected static boolean largerPopulation(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { - return isReaptable == RRuntime.LOGICAL_FALSE && size > x; + protected static boolean largerPopulation(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) { + return isRepeatable == RRuntime.LOGICAL_FALSE && size > x; } @SuppressWarnings("unused") - protected static boolean invalidIsRepeatable(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { - return RRuntime.isNA(isReaptable); + protected static boolean invalidIsRepeatable(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) { + return RRuntime.isNA(isRepeatable); } private int[] probSampleReplace(int n, double[] probArray, int resultSize) { + // The following code is transcribed from GNU R src/main/random.c lines 309-335 int[] result = new int[resultSize]; int[] perm = new int[n]; @@ -252,6 +240,7 @@ public abstract class Sample extends RBuiltinNode { } private int[] probSampleWithoutReplace(int n, double[] probArray, int resultSize) { + // The following code is transcribed from GNU R src/main/random.c lines 396-428 int[] ans = new int[resultSize]; int[] perm = new int[n]; for (int i = 0; i < n; i++) { @@ -280,14 +269,14 @@ public abstract class Sample extends RBuiltinNode { } private void buildheap(double[] keys, int[] values) { - for (int i = keys.length / 2; i >= 0; i--) { + for (int i = (keys.length >> 1); i >= 0; i--) { minHeapify(keys, i, keys.length, values); } } private void minHeapify(double[] keys, int currentIndex, int heapSize, int[] values) { - int leftChildIndex = 2 * currentIndex; - int rightChildIndex = 2 * currentIndex + 1; + int leftChildIndex = currentIndex << 1; + int rightChildIndex = leftChildIndex + 1; int lowestElementIndex = currentIndex; if (leftChildIndex < heapSize && keys[leftChildIndex] < keys[currentIndex]) { lowestElementIndex = leftChildIndex;