Skip to content
Snippets Groups Projects
Commit 5b063d82 authored by Prahlad Joshi's avatar Prahlad Joshi
Browse files

Incorporating review comments

parent fd0d1dec
Branches
No related tags found
No related merge requests found
...@@ -3,33 +3,15 @@ ...@@ -3,33 +3,15 @@
* Version 2. You may review the terms of this license at * Version 2. You may review the terms of this license at
* http://www.gnu.org/licenses/gpl-2.0.html * http://www.gnu.org/licenses/gpl-2.0.html
* *
* Copyright (c) 2012-2014, Purdue University * Copyright (c) 1995, 1996, Robert Gentleman and Ross Ihaka
* Copyright (c) 1997--2012, The R Core Team
* Copyright (c) 2003--2008, The R Foundation
* Copyright (c) 2012-2013, Purdue University
* Copyright (c) 2013, 2014, Oracle and/or its affiliates * Copyright (c) 2013, 2014, Oracle and/or its affiliates
* *
* All rights reserved. * All rights reserved.
*/ */
/*
* R : A Computer Language for Statistical Data Analysis
* Copyright (C) 1995, 1996 Robert Gentleman and Ross Ihaka
* Copyright (C) 1997--2012 The R Core Team
* Copyright (C) 2003--2008 The R Foundation
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, a copy is available at
* http://www.r-project.org/Licenses/
*/
package com.oracle.truffle.r.nodes.builtin.base; package com.oracle.truffle.r.nodes.builtin.base;
import com.oracle.truffle.api.*; import com.oracle.truffle.api.*;
...@@ -55,37 +37,39 @@ public abstract class Sample extends RBuiltinNode { ...@@ -55,37 +37,39 @@ public abstract class Sample extends RBuiltinNode {
@Specialization(order = 10, guards = "invalidFirstArgument") @Specialization(order = 10, guards = "invalidFirstArgument")
@SuppressWarnings("unused") @SuppressWarnings("unused")
public RIntVector doSampleInvalidFirstArg(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { public RIntVector doSampleInvalidFirstArg(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) {
CompilerDirectives.transferToInterpreter(); CompilerDirectives.transferToInterpreter();
throw RError.getInvalidFirstArgument(getEncapsulatingSourceSection()); throw RError.getInvalidFirstArgument(getEncapsulatingSourceSection());
} }
@Specialization(order = 20, guards = "invalidProb") @Specialization(order = 20, guards = "invalidProb")
@SuppressWarnings("unused") @SuppressWarnings("unused")
public RIntVector doSampleInvalidProb(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { public RIntVector doSampleInvalidProb(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) {
CompilerDirectives.transferToInterpreter(); CompilerDirectives.transferToInterpreter();
throw RError.getIncorrectNumProb(getEncapsulatingSourceSection()); throw RError.getIncorrectNumProb(getEncapsulatingSourceSection());
} }
@Specialization(order = 30, guards = "largerPopulation") @Specialization(order = 30, guards = "largerPopulation")
@SuppressWarnings("unused") @SuppressWarnings("unused")
public RIntVector doSampleLargerPopulation(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { public RIntVector doSampleLargerPopulation(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) {
CompilerDirectives.transferToInterpreter(); CompilerDirectives.transferToInterpreter();
throw RError.getLargerPopu(getEncapsulatingSourceSection()); throw RError.getLargerPopu(getEncapsulatingSourceSection());
} }
@Specialization(order = 40, guards = "invalidSizeArgument") @Specialization(order = 40, guards = "invalidSizeArgument")
@SuppressWarnings("unused") @SuppressWarnings("unused")
public RIntVector doSampleInvalidSize(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { public RIntVector doSampleInvalidSize(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) {
CompilerDirectives.transferToInterpreter(); CompilerDirectives.transferToInterpreter();
throw RError.getInvalidArgument(getEncapsulatingSourceSection(), RRuntime.toString(size)); throw RError.getInvalidArgument(getEncapsulatingSourceSection(), RRuntime.toString(size));
} }
@Specialization(order = 1, guards = {"!invalidFirstArgument", "!invalidProb", "!largerPopulation", "!invalidSizeArgument", "withReplacement"}) @Specialization(order = 1, guards = {"!invalidFirstArgument", "!invalidProb", "!largerPopulation", "!invalidSizeArgument", "withReplacement"})
public RIntVector doSampleWithReplacement(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { public RIntVector doSampleWithReplacement(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) {
// The following code is transcribed from GNU R src/main/random.c lines 493-501 in
// function do_sample.
double[] probArray = prob.getDataCopy(); double[] probArray = prob.getDataCopy();
fixupProbability(probArray, x, size, isReaptable); fixupProbability(probArray, x, size, isRepeatable);
int nc = 0; int nc = 0;
for (double aProb : probArray) { for (double aProb : probArray) {
if (x * aProb > 0.1) { if (x * aProb > 0.1) {
...@@ -101,39 +85,41 @@ public abstract class Sample extends RBuiltinNode { ...@@ -101,39 +85,41 @@ public abstract class Sample extends RBuiltinNode {
} }
@Specialization(order = 2, guards = {"!invalidFirstArgument", "!invalidProb", "!largerPopulation", "!invalidSizeArgument", "!withReplacement"}) @Specialization(order = 2, guards = {"!invalidFirstArgument", "!invalidProb", "!largerPopulation", "!invalidSizeArgument", "!withReplacement"})
public RIntVector doSampleNoReplacement(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { public RIntVector doSampleNoReplacement(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) {
double[] probArray = prob.getDataCopy(); double[] probArray = prob.getDataCopy();
fixupProbability(probArray, x, size, isReaptable); fixupProbability(probArray, x, size, isRepeatable);
return RDataFactory.createIntVector(probSampleWithoutReplace(x, probArray, size), RDataFactory.COMPLETE_VECTOR); return RDataFactory.createIntVector(probSampleWithoutReplace(x, probArray, size), RDataFactory.COMPLETE_VECTOR);
} }
@SuppressWarnings("unused") @SuppressWarnings("unused")
@Specialization(order = 50, guards = "invalidFirstArgumentNullProb") @Specialization(order = 50, guards = "invalidFirstArgumentNullProb")
public RIntVector doSampleInvalidFirstArgument(final int x, final int size, final byte isReaptable, final RNull prob) { public RIntVector doSampleInvalidFirstArgument(final int x, final int size, final byte isRepeatable, final RNull prob) {
CompilerDirectives.transferToInterpreter(); CompilerDirectives.transferToInterpreter();
throw RError.getInvalidFirstArgument(getEncapsulatingSourceSection()); throw RError.getInvalidFirstArgument(getEncapsulatingSourceSection());
} }
@SuppressWarnings("unused") @SuppressWarnings("unused")
@Specialization(order = 60, guards = "invalidSizeArgument") @Specialization(order = 60, guards = "invalidSizeArgument")
public RIntVector doSampleInvalidSizeArgument(final int x, final int size, final byte isReaptable, final RNull prob) { public RIntVector doSampleInvalidSizeArgument(final int x, final int size, final byte isRepeatable, final RNull prob) {
CompilerDirectives.transferToInterpreter(); CompilerDirectives.transferToInterpreter();
throw RError.getInvalidArgument(getEncapsulatingSourceSection(), RRuntime.toString(size)); throw RError.getInvalidArgument(getEncapsulatingSourceSection(), RRuntime.toString(size));
} }
@SuppressWarnings("unused") @SuppressWarnings("unused")
@Specialization(order = 70, guards = "largerPopulation") @Specialization(order = 70, guards = "largerPopulation")
public RIntVector doSampleInvalidLargerPopulation(final int x, final int size, final byte isReaptable, final RNull prob) { public RIntVector doSampleInvalidLargerPopulation(final int x, final int size, final byte isRepeatable, final RNull prob) {
CompilerDirectives.transferToInterpreter(); CompilerDirectives.transferToInterpreter();
throw RError.getIncorrectNumProb(getEncapsulatingSourceSection()); throw RError.getIncorrectNumProb(getEncapsulatingSourceSection());
} }
@Specialization(order = 80, guards = {"!invalidFirstArgumentNullProb", "!invalidSizeArgument", "!largerPopulation"}) @Specialization(order = 80, guards = {"!invalidFirstArgumentNullProb", "!invalidSizeArgument", "!largerPopulation"})
public RIntVector doSample(final int x, final int size, final byte isReaptable, @SuppressWarnings("unused") final RNull prob) { public RIntVector doSample(final int x, final int size, final byte isRepeatable, @SuppressWarnings("unused") final RNull prob) {
// TODO:Add support of long integers. // TODO:Add support of long integers.
// The following code is transcribed from GNU R src/main/random.c lines 533-545 in
// function do_sample.
int[] result = new int[size]; int[] result = new int[size];
/* avoid allocation for a single sample */ /* avoid allocation for a single sample */
if (isReaptable == RRuntime.LOGICAL_TRUE || size < 2) { if (isRepeatable == RRuntime.LOGICAL_TRUE || size < 2) {
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
result[i] = (int) (x * RRNG.unifRand() + 1); result[i] = (int) (x * RRNG.unifRand() + 1);
} }
...@@ -154,12 +140,13 @@ public abstract class Sample extends RBuiltinNode { ...@@ -154,12 +140,13 @@ public abstract class Sample extends RBuiltinNode {
@SuppressWarnings("unused") @SuppressWarnings("unused")
@Specialization(order = 100, guards = "invalidIsRepeatable") @Specialization(order = 100, guards = "invalidIsRepeatable")
public RIntVector doSampleInvalidIsRepeatable(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { public RIntVector doSampleInvalidIsRepeatable(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) {
CompilerDirectives.transferToInterpreter(); CompilerDirectives.transferToInterpreter();
throw RError.getInvalidArgument(getEncapsulatingSourceSection(), RRuntime.toString(isReaptable)); throw RError.getInvalidArgument(getEncapsulatingSourceSection(), RRuntime.toString(isRepeatable));
} }
private void fixupProbability(double[] probArray, int x, int size, byte isRepeatable) { private void fixupProbability(double[] probArray, int x, int size, byte isRepeatable) {
// The following code is transcribed from GNU R src/main/random.c lines 429-449
int nonZeroProbCount = 0; int nonZeroProbCount = 0;
double probSum = 0; double probSum = 0;
for (double aProb : probArray) { for (double aProb : probArray) {
...@@ -183,51 +170,52 @@ public abstract class Sample extends RBuiltinNode { ...@@ -183,51 +170,52 @@ public abstract class Sample extends RBuiltinNode {
} }
@SuppressWarnings("unused") @SuppressWarnings("unused")
protected static boolean invalidFirstArgumentNullProb(final int x, final int size, final byte isReaptable, final RNull prob) { protected static boolean invalidFirstArgumentNullProb(final int x, final int size, final byte isRepeatable, final RNull prob) {
return !RRuntime.isFinite(x) || x < 0 || x > 4.5e15 || (size > 0 && x == 0); return !RRuntime.isFinite(x) || x < 0 || x > 4.5e15 || (size > 0 && x == 0);
} }
@SuppressWarnings("unused") @SuppressWarnings("unused")
protected static boolean invalidSizeArgument(final int x, final int size, final byte isReaptable, final RNull prob) { protected static boolean invalidSizeArgument(final int x, final int size, final byte isRepeatable, final RNull prob) {
return RRuntime.isNA(size) || size < 0; return RRuntime.isNA(size) || size < 0;
} }
@SuppressWarnings("unused") @SuppressWarnings("unused")
protected static boolean largerPopulation(final int x, final int size, final byte isReaptable, final RNull prob) { protected static boolean largerPopulation(final int x, final int size, final byte isRepeatable, final RNull prob) {
return isReaptable == RRuntime.LOGICAL_FALSE && size > x; return isRepeatable == RRuntime.LOGICAL_FALSE && size > x;
} }
@SuppressWarnings("unused") @SuppressWarnings("unused")
protected static boolean invalidFirstArgument(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { protected static boolean invalidFirstArgument(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) {
return RRuntime.isNA(x) || x < 0 || (size > 0 && x == 0); return RRuntime.isNA(x) || x < 0 || (size > 0 && x == 0);
} }
@SuppressWarnings("unused") @SuppressWarnings("unused")
protected static boolean invalidSizeArgument(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { protected static boolean invalidSizeArgument(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) {
return RRuntime.isNA(size) || size < 0; return RRuntime.isNA(size) || size < 0;
} }
@SuppressWarnings("unused") @SuppressWarnings("unused")
protected static boolean invalidProb(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { protected static boolean invalidProb(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) {
return prob.getLength() != x; return prob.getLength() != x;
} }
@SuppressWarnings("unused") @SuppressWarnings("unused")
protected static boolean withReplacement(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { protected static boolean withReplacement(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) {
return isReaptable == RRuntime.LOGICAL_TRUE; return isRepeatable == RRuntime.LOGICAL_TRUE;
} }
@SuppressWarnings("unused") @SuppressWarnings("unused")
protected static boolean largerPopulation(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { protected static boolean largerPopulation(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) {
return isReaptable == RRuntime.LOGICAL_FALSE && size > x; return isRepeatable == RRuntime.LOGICAL_FALSE && size > x;
} }
@SuppressWarnings("unused") @SuppressWarnings("unused")
protected static boolean invalidIsRepeatable(final int x, final int size, final byte isReaptable, final RDoubleVector prob) { protected static boolean invalidIsRepeatable(final int x, final int size, final byte isRepeatable, final RDoubleVector prob) {
return RRuntime.isNA(isReaptable); return RRuntime.isNA(isRepeatable);
} }
private int[] probSampleReplace(int n, double[] probArray, int resultSize) { private int[] probSampleReplace(int n, double[] probArray, int resultSize) {
// The following code is transcribed from GNU R src/main/random.c lines 309-335
int[] result = new int[resultSize]; int[] result = new int[resultSize];
int[] perm = new int[n]; int[] perm = new int[n];
...@@ -252,6 +240,7 @@ public abstract class Sample extends RBuiltinNode { ...@@ -252,6 +240,7 @@ public abstract class Sample extends RBuiltinNode {
} }
private int[] probSampleWithoutReplace(int n, double[] probArray, int resultSize) { private int[] probSampleWithoutReplace(int n, double[] probArray, int resultSize) {
// The following code is transcribed from GNU R src/main/random.c lines 396-428
int[] ans = new int[resultSize]; int[] ans = new int[resultSize];
int[] perm = new int[n]; int[] perm = new int[n];
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
...@@ -280,14 +269,14 @@ public abstract class Sample extends RBuiltinNode { ...@@ -280,14 +269,14 @@ public abstract class Sample extends RBuiltinNode {
} }
private void buildheap(double[] keys, int[] values) { private void buildheap(double[] keys, int[] values) {
for (int i = keys.length / 2; i >= 0; i--) { for (int i = (keys.length >> 1); i >= 0; i--) {
minHeapify(keys, i, keys.length, values); minHeapify(keys, i, keys.length, values);
} }
} }
private void minHeapify(double[] keys, int currentIndex, int heapSize, int[] values) { private void minHeapify(double[] keys, int currentIndex, int heapSize, int[] values) {
int leftChildIndex = 2 * currentIndex; int leftChildIndex = currentIndex << 1;
int rightChildIndex = 2 * currentIndex + 1; int rightChildIndex = leftChildIndex + 1;
int lowestElementIndex = currentIndex; int lowestElementIndex = currentIndex;
if (leftChildIndex < heapSize && keys[leftChildIndex] < keys[currentIndex]) { if (leftChildIndex < heapSize && keys[leftChildIndex] < keys[currentIndex]) {
lowestElementIndex = leftChildIndex; lowestElementIndex = leftChildIndex;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment