Skip to content
Snippets Groups Projects
Commit 5d83a0a7 authored by stepan's avatar stepan
Browse files

Interface change in RandGenerationFunctions that should address performance

parent 20a08054
No related branches found
No related tags found
No related merge requests found
......@@ -37,11 +37,10 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.nodes.RNode;
import com.oracle.truffle.r.runtime.ops.na.NACheck;
import com.oracle.truffle.r.runtime.rng.RRNG;
import com.oracle.truffle.r.runtime.rng.RandomNumberNode;
public final class RandGenerationFunctions {
public static final double ERR_NA = RRuntime.DOUBLE_NA;
private static final RDouble DUMMY_VECTOR = RDouble.valueOf(1);
private RandGenerationFunctions() {
......@@ -50,46 +49,57 @@ public final class RandGenerationFunctions {
// inspired by the DEFRAND{X}_REAL and DEFRAND{X}_INT macros in GnuR
public interface RandFunction3_Int {
int evaluate(double a, double b, double c);
public interface RandFunction {
/**
* Allows to execute any initialization logic before the main loop that generates the
* resulting vector. This is place where the function should generate necessary random
* values if possible.
*/
default void init(int resultLength, RandomNumberNode randNode) {
RRNG.getRNGState();
}
default void finish() {
RRNG.putRNGState();
}
}
public interface RandFunction3_Int extends RandFunction {
int evaluate(int index, double a, double b, double c, RandomNumberNode randomNode);
}
public interface RandFunction2_Int extends RandFunction3_Int {
@Override
default int evaluate(double a, double b, double c) {
return evaluate(a, b);
default int evaluate(int index, double a, double b, double c, RandomNumberNode randomNode) {
return evaluate(index, a, b, randomNode);
}
int evaluate(double a, double b);
int evaluate(int index, double a, double b, RandomNumberNode randomNode);
}
public interface RandFunction2_Double {
public interface RandFunction2_Double extends RandFunction {
/**
* If returns {@code false}, {@link #evaluate(double, double)} will not be invoked.
* Opt-in possibility for random functions returning double: the infrastructure will
* preallocate array of random values and reuse it for storing the result. The random values
* will be passed to {@link #evaluate(int, double, double, double, RandomNumberNode)} as the
* 'random' argument. If this method returns {@code true} (default), the random numbers
* generation can be done in {@link #init(int, RandomNumberNode)} and {@link #finish()} or
* in {@link #evaluate(int, double, double, double, RandomNumberNode)}.
*/
boolean isValid(double a, double b);
default boolean hasCustomRandomGeneration() {
return true;
}
/**
* Is guaranteed to be preceded by invocation of {@link #isValid(double, double)} with the
* same arguments.
* Should generate the value that will be stored to the result vector under given index.
* Error is indicated by returning {@link StatsUtil#mlError()}.
*/
double evaluate(double a, double b);
}
public abstract static class RandFunction2_DoubleAdapter implements RandFunction2_Double {
@Override
public boolean isValid(double a, double b) {
return true;
}
double evaluate(int index, double a, double b, double random, RandomNumberNode randomNode);
}
static final class RandGenerationProfiles {
final BranchProfile nanResult = BranchProfile.create();
final BranchProfile errResult = BranchProfile.create();
final BranchProfile nan = BranchProfile.create();
final NACheck aCheck = NACheck.create();
final NACheck bCheck = NACheck.create();
final NACheck cCheck = NACheck.create();
final VectorLengthProfile resultVectorLengthProfile = VectorLengthProfile.create();
final LoopConditionProfile loopConditionProfile = LoopConditionProfile.createCountingProfile();
......@@ -99,7 +109,7 @@ public final class RandGenerationFunctions {
}
private static RAbstractIntVector evaluate3Int(Node node, RandFunction3_Int function, int lengthIn, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c,
RandGenerationProfiles profiles) {
RandGenerationProfiles profiles, RandomNumberNode randNode) {
int length = profiles.resultVectorLengthProfile.profile(lengthIn);
int aLength = a.getLength();
int bLength = b.getLength();
......@@ -113,41 +123,30 @@ public final class RandGenerationFunctions {
}
RNode.reportWork(node, length);
boolean complete = true;
boolean nans = false;
profiles.aCheck.enable(a);
profiles.bCheck.enable(b);
profiles.cCheck.enable(c);
int[] result = new int[length];
RRNG.getRNGState();
function.init(length, randNode);
for (int i = 0; profiles.loopConditionProfile.inject(i < length); i++) {
double aValue = a.getDataAt(i % aLength);
double bValue = b.getDataAt(i % bLength);
double cValue = c.getDataAt(i % cLength);
int value;
if (Double.isNaN(aValue) || Double.isNaN(bValue) || Double.isNaN(cValue)) {
int value = function.evaluate(i, aValue, bValue, cValue, randNode);
if (Double.isNaN(value)) {
profiles.nan.enter();
value = RRuntime.INT_NA;
if (profiles.aCheck.check(aValue) || profiles.bCheck.check(bValue) || profiles.cCheck.check(cValue)) {
complete = false;
}
} else {
value = function.evaluate(aValue, bValue, cValue);
if (Double.isNaN(value)) {
profiles.nan.enter();
nans = true;
}
nans = true;
}
result[i] = value;
}
RRNG.putRNGState();
function.finish();
if (nans) {
RError.warning(SHOW_CALLER, RError.Message.NAN_PRODUCED);
}
return RDataFactory.createIntVector(result, complete);
return RDataFactory.createIntVector(result, !nans);
}
private static RAbstractDoubleVector evaluate2Double(Node node, RandFunction2_Double function, int length, RAbstractDoubleVector a, RAbstractDoubleVector b, RandGenerationProfiles profiles) {
private static RAbstractDoubleVector evaluate2Double(Node node, RandFunction2_Double function, int lengthIn, RAbstractDoubleVector a, RAbstractDoubleVector b, RandGenerationProfiles profiles,
RandomNumberNode randNode) {
int length = profiles.resultVectorLengthProfile.profile(lengthIn);
int aLength = a.getLength();
int bLength = b.getLength();
if (aLength == 0 || bLength == 0) {
......@@ -157,41 +156,33 @@ public final class RandGenerationFunctions {
}
RNode.reportWork(node, length);
boolean complete = true;
boolean nans = false;
profiles.aCheck.enable(a);
profiles.bCheck.enable(b);
double[] result = new double[length];
RRNG.getRNGState();
for (int i = 0; i < length; i++) {
double[] result;
if (function.hasCustomRandomGeneration()) {
function.init(length, randNode);
result = new double[length];
} else {
RRNG.getRNGState();
result = randNode.executeDouble(length);
RRNG.putRNGState();
}
for (int i = 0; profiles.loopConditionProfile.inject(i < length); i++) {
double aValue = a.getDataAt(i % aLength);
double bValue = b.getDataAt(i % bLength);
double value;
if (Double.isNaN(aValue) || Double.isNaN(bValue)) {
double value = function.evaluate(i, aValue, bValue, result[i], randNode);
if (Double.isNaN(value)) {
profiles.nan.enter();
value = RRuntime.INT_NA;
if (profiles.aCheck.check(aValue) || profiles.bCheck.check(bValue)) {
complete = false;
}
} else {
if (!function.isValid(aValue, bValue)) {
profiles.errResult.enter();
RError.warning(SHOW_CALLER, RError.Message.NA_PRODUCED);
return createVectorOf(length, Double.NaN);
}
value = function.evaluate(aValue, bValue);
if (Double.isNaN(value)) {
profiles.nan.enter();
nans = true;
}
nans = true;
}
result[i] = value;
}
RRNG.putRNGState();
if (function.hasCustomRandomGeneration()) {
function.finish();
}
if (nans) {
RError.warning(SHOW_CALLER, RError.Message.NAN_PRODUCED);
RError.warning(SHOW_CALLER, RError.Message.NA_PRODUCED);
}
return RDataFactory.createDoubleVector(result, complete);
return RDataFactory.createDoubleVector(result, !nans);
}
private static RAbstractDoubleVector createVectorOf(int length, double element) {
......@@ -248,8 +239,9 @@ public final class RandGenerationFunctions {
@Specialization
protected RAbstractIntVector evaluate(RAbstractVector length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c,
@Cached("create()") RandGenerationProfiles profiles) {
return evaluate3Int(this, function, convertToLength.execute(length), a, b, c, profiles);
@Cached("create()") RandGenerationProfiles profiles,
@Cached("create()") RandomNumberNode randNode) {
return evaluate3Int(this, function, convertToLength.execute(length), a, b, c, profiles, randNode);
}
}
......@@ -270,8 +262,9 @@ public final class RandGenerationFunctions {
@Specialization
protected RAbstractIntVector evaluate(RAbstractVector length, RAbstractDoubleVector a, RAbstractDoubleVector b,
@Cached("create()") RandGenerationProfiles profiles) {
return evaluate3Int(this, function, convertToLength.execute(length), a, b, DUMMY_VECTOR, profiles);
@Cached("create()") RandGenerationProfiles profiles,
@Cached("create()") RandomNumberNode randNode) {
return evaluate3Int(this, function, convertToLength.execute(length), a, b, DUMMY_VECTOR, profiles, randNode);
}
}
......@@ -292,8 +285,9 @@ public final class RandGenerationFunctions {
@Specialization
protected RAbstractDoubleVector evaluate(RAbstractVector length, RAbstractDoubleVector a, RAbstractDoubleVector b,
@Cached("create()") RandGenerationProfiles profiles) {
return evaluate2Double(this, function, convertToLength.execute(length), a, b, profiles);
@Cached("create()") RandGenerationProfiles profiles,
@Cached("create()") RandomNumberNode randNode) {
return evaluate2Double(this, function, convertToLength.execute(length), a, b, profiles, randNode);
}
}
}
......@@ -12,10 +12,9 @@
*/
package com.oracle.truffle.r.library.stats;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_Int;
import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.rng.RRNG;
import com.oracle.truffle.r.runtime.rng.RandomNumberNode;
// transcribed from rbinom.c
......@@ -23,13 +22,12 @@ public final class Rbinom implements RandFunction2_Int {
private final Qbinom qbinom = new Qbinom();
@TruffleBoundary
private static double unifRand() {
return RRNG.unifRand();
private static double unifRand(RandomNumberNode randNode) {
return randNode.executeDouble(1)[0];
}
@Override
public int evaluate(double nin, double pp) {
public int evaluate(int index, double nin, double pp, RandomNumberNode randomNode) {
double psave = -1.0;
int nsave = -1;
......@@ -56,7 +54,7 @@ public final class Rbinom implements RandFunction2_Int {
/*
* evade integer overflow, and r == INT_MAX gave only even values
*/
return (int) qbinom.evaluate(unifRand(), r, pp, /* lower_tail */false, /* log_p */false);
return (int) qbinom.evaluate(unifRand(randomNode), r, pp, /* lower_tail */false, /* log_p */false);
}
/* else */
int n = (int) r;
......@@ -137,8 +135,8 @@ public final class Rbinom implements RandFunction2_Int {
/*-------------------------- np = n*p >= 30 : ------------------- */
while (true) {
u = unifRand() * p4;
v = unifRand();
u = unifRand(randomNode) * p4;
v = unifRand(randomNode);
/* triangular region */
if (u <= p1) {
ix = (int) (xm - p1 * v + u);
......@@ -225,7 +223,7 @@ public final class Rbinom implements RandFunction2_Int {
while (true) {
ix = 0;
f = qn;
u = unifRand();
u = unifRand(randomNode);
while (true) {
if (u < f) {
// goto finis;
......
......@@ -11,17 +11,30 @@
*/
package com.oracle.truffle.r.library.stats;
import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_DoubleAdapter;
import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_Double;
import com.oracle.truffle.r.runtime.rng.RRNG;
import com.oracle.truffle.r.runtime.rng.RandomNumberNode;
public final class Rnorm extends RandFunction2_DoubleAdapter {
public final class Rnorm implements RandFunction2_Double {
private static final double BIG = 134217728;
private double[] randomVals;
@Override
public double evaluate(double mu, double sigma) {
public void init(int length, RandomNumberNode randNode) {
RRNG.getRNGState();
randomVals = randNode.executeDouble(length * 2);
}
@Override
public void finish() {
RRNG.putRNGState();
}
@Override
public double evaluate(int index, double mu, double sigma, double random, RandomNumberNode randomNode) {
// TODO: GnuR invokes norm_rand to get "rand"
double u1 = (int) (BIG * RRNG.unifRand()) + RRNG.unifRand();
double u1 = (int) (BIG * randomVals[index * 2]) + randomVals[index * 2 + 1];
double rand = Random2.qnorm5(u1 / BIG, 0.0, 1.0, true, false);
return rand * sigma + mu;
}
......
......@@ -24,16 +24,20 @@ package com.oracle.truffle.r.library.stats;
import com.oracle.truffle.r.library.stats.RandGenerationFunctions.RandFunction2_Double;
import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.rng.RRNG;
import com.oracle.truffle.r.runtime.rng.RandomNumberNode;
public final class Runif implements RandFunction2_Double {
@Override
public boolean isValid(double min, double max) {
return RRuntime.isFinite(min) && RRuntime.isFinite(max) && max >= min;
public boolean hasCustomRandomGeneration() {
return false;
}
@Override
public double evaluate(double min, double max) {
return min + RRNG.unifRand() * (max - min);
public double evaluate(int index, double min, double max, double random, RandomNumberNode randomNode) {
if (!RRuntime.isFinite(min) || !RRuntime.isFinite(max) || max < min) {
return StatsUtil.mlError();
}
return min + random * (max - min);
}
}
......@@ -21,6 +21,13 @@ import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
*/
public class StatsUtil {
/**
* corresponds to macro {@code ML_ERR_return_NAN} in GnuR.
*/
public static double mlError() {
return Double.NaN;
}
public static final double DBLEPSILON = 2.2204460492503131e-16;
@TruffleBoundary
......@@ -150,7 +157,7 @@ public class StatsUtil {
// GNUR from log1p.c
//
@CompilationFinal private static final double[] alnrcs = {+.10378693562743769800686267719098e+1, -.13364301504908918098766041553133e+0, +.19408249135520563357926199374750e-1,
@CompilationFinal(dimensions = 1) private static final double[] alnrcs = {+.10378693562743769800686267719098e+1, -.13364301504908918098766041553133e+0, +.19408249135520563357926199374750e-1,
-.30107551127535777690376537776592e-2, +.48694614797154850090456366509137e-3, -.81054881893175356066809943008622e-4, +.13778847799559524782938251496059e-4,
-.23802210894358970251369992914935e-5, +.41640416213865183476391859901989e-6, -.73595828378075994984266837031998e-7, +.13117611876241674949152294345011e-7,
-.23546709317742425136696092330175e-8, +.42522773276034997775638052962567e-9, -.77190894134840796826108107493300e-10, +.14075746481359069909215356472191e-10,
......
/*
* Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
......@@ -33,4 +33,12 @@ public final class RandomNumberNode extends RBaseNode {
public double[] executeDouble(int count) {
return generatorClassProfile.profile(generatorProfile.profile(RRNG.currentGenerator())).genrandDouble(count);
}
public double executeSingleDouble() {
return generatorClassProfile.profile(generatorProfile.profile(RRNG.currentGenerator())).genrandDouble(1)[0];
}
public static RandomNumberNode create() {
return new RandomNumberNode();
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment