Skip to content
Snippets Groups Projects
Commit ed3e1245 authored by Lukas Stadler's avatar Lukas Stadler
Browse files

convert RNG and stats functions to VectorAccess

parent ba3a4eb7
No related branches found
No related tags found
No related merge requests found
...@@ -27,18 +27,15 @@ import com.oracle.truffle.api.profiles.ValueProfile; ...@@ -27,18 +27,15 @@ import com.oracle.truffle.api.profiles.ValueProfile;
import com.oracle.truffle.r.nodes.attributes.GetFixedAttributeNode; import com.oracle.truffle.r.nodes.attributes.GetFixedAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SetFixedAttributeNode; import com.oracle.truffle.r.nodes.attributes.SetFixedAttributeNode;
import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode; import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode;
import com.oracle.truffle.r.nodes.function.opt.ReuseNonSharedNode;
import com.oracle.truffle.r.nodes.function.opt.UpdateShareableChildValueNode; import com.oracle.truffle.r.nodes.function.opt.UpdateShareableChildValueNode;
import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.RError.Message;
import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RDoubleVector;
import com.oracle.truffle.r.runtime.data.RIntVector; import com.oracle.truffle.r.runtime.data.RIntVector;
import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.nodes.ReadAccessor; import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.SetDataAt; import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
import com.oracle.truffle.r.runtime.data.nodes.VectorReadAccess;
import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandomNumberProvider; import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandomNumberProvider;
import com.oracle.truffle.r.runtime.nmath.distr.RMultinom; import com.oracle.truffle.r.runtime.nmath.distr.RMultinom;
import com.oracle.truffle.r.runtime.nmath.distr.Rbinom; import com.oracle.truffle.r.runtime.nmath.distr.Rbinom;
...@@ -49,8 +46,15 @@ import com.oracle.truffle.r.runtime.rng.RRNG; ...@@ -49,8 +46,15 @@ import com.oracle.truffle.r.runtime.rng.RRNG;
* Implements the vectorization of {@link RMultinom}. * Implements the vectorization of {@link RMultinom}.
*/ */
public abstract class RMultinomNode extends RExternalBuiltinNode.Arg3 { public abstract class RMultinomNode extends RExternalBuiltinNode.Arg3 {
private final Rbinom rbinom = new Rbinom(); private final Rbinom rbinom = new Rbinom();
private final ValueProfile randGeneratorClassProfile = ValueProfile.createClassProfile();
private final ConditionProfile hasAttributesProfile = ConditionProfile.createBinaryProfile();
@Child private UpdateShareableChildValueNode updateSharedAttributeNode = UpdateShareableChildValueNode.create();
@Child private GetFixedAttributeNode getNamesNode = GetFixedAttributeNode.createNames();
@Child private SetFixedAttributeNode setDimNamesNode = SetFixedAttributeNode.createDimNames();
public static RMultinomNode create() { public static RMultinomNode create() {
return RMultinomNodeGen.create(); return RMultinomNodeGen.create();
} }
...@@ -68,63 +72,69 @@ public abstract class RMultinomNode extends RExternalBuiltinNode.Arg3 { ...@@ -68,63 +72,69 @@ public abstract class RMultinomNode extends RExternalBuiltinNode.Arg3 {
} }
@Specialization @Specialization
protected RIntVector doMultinom(int n, int size, RAbstractDoubleVector probsVec, protected RIntVector doMultinom(int n, int size, RAbstractDoubleVector probs,
@Cached("create()") VectorReadAccess.Double probsAccess, @Cached("probs.access()") VectorAccess probsAccess) {
@Cached("create()") SetDataAt.Double probsSetter, try (SequentialIterator probsIter = probsAccess.access(probs)) {
@Cached("create()") ReuseNonSharedNode reuseNonSharedNode, double sum = 0.0;
@Cached("createClassProfile()") ValueProfile randGeneratorClassProfile, while (probsAccess.next(probsIter)) {
@Cached("createBinaryProfile()") ConditionProfile hasAttributesProfile, double prob = probsAccess.getDouble(probsIter);
@Cached("create()") UpdateShareableChildValueNode updateSharedAttributeNode, if (!Double.isFinite(prob)) {
@Cached("createNames()") GetFixedAttributeNode getNamesNode, throw error(NA_IN_PROB_VECTOR);
@Cached("createDimNames()") SetFixedAttributeNode setDimNamesNode) { }
RDoubleVector nonSharedProbs = ((RAbstractDoubleVector) reuseNonSharedNode.execute(probsVec)).materialize(); if (prob < 0.0) {
ReadAccessor.Double probs = new ReadAccessor.Double(nonSharedProbs, probsAccess); throw error(NEGATIVE_PROBABILITY);
fixupProb(nonSharedProbs, probs, probsSetter); }
sum += prob;
RRNG.getRNGState(); }
RandomNumberProvider rand = new RandomNumberProvider(randGeneratorClassProfile.profile(RRNG.currentGenerator()), RRNG.currentNormKind()); if (sum == 0) {
int k = nonSharedProbs.getLength(); throw error(NO_POSITIVE_PROBABILITIES);
int[] result = new int[k * n]; }
boolean isComplete = true;
for (int i = 0, ik = 0; i < n; i++, ik += k) {
isComplete &= RMultinom.rmultinom(size, probs, k, result, ik, rand, rbinom);
}
RRNG.putRNGState();
// take names from probVec (if any) as row names in the result RRNG.getRNGState();
RIntVector resultVec = RDataFactory.createIntVector(result, isComplete, new int[]{k, n}); RandomNumberProvider rand = new RandomNumberProvider(randGeneratorClassProfile.profile(RRNG.currentGenerator()), RRNG.currentNormKind());
if (hasAttributesProfile.profile(probsVec.getAttributes() != null)) { int[] result = new int[probsAccess.getLength(probsIter) * n];
Object probsNames = getNamesNode.execute(probsVec.getAttributes()); if (size > 0) {
updateSharedAttributeNode.execute(probsVec, probsNames); for (int i = 0, ik = 0; i < n; i++, ik += probsAccess.getLength(probsIter)) {
Object[] dimnamesData = new Object[]{probsNames, RNull.instance}; double currentSum = sum;
setDimNamesNode.execute(resultVec.getAttributes(), RDataFactory.createList(dimnamesData)); int currentSize = size;
} /* Generate the first K-1 obs. via binomials */
return resultVec; probsAccess.reset(probsIter);
} for (int k = 0; probsAccess.next(probsIter) && k < probsAccess.getLength(probsIter) - 1; k++) {
/* (p_tot, n) are for "remaining binomial" */
/* LDOUBLE */double probK = probsAccess.getDouble(probsIter);
if (probK != 0.) {
double pp = probK / currentSum;
int value = (pp < 1.) ? (int) rbinom.execute(currentSize, pp, rand) : currentSize;
/*
* >= 1; > 1 happens because of rounding
*/
result[ik + k] = value;
currentSize -= value;
} else {
result[ik + k] = 0;
}
if (n <= 0) {
/* we have all */
break;
}
/* i.e. = sum(prob[(k+1):K]) */
currentSum -= probK;
}
private void fixupProb(RDoubleVector p, ReadAccessor.Double pAccess, SetDataAt.Double pSetter) { result[ik + probsAccess.getLength(probsIter) - 1] = currentSize;
double sum = 0.0; }
int npos = 0;
int pLength = p.getLength();
for (int i = 0; i < pLength; i++) {
double prob = pAccess.getDataAt(i);
if (!Double.isFinite(prob)) {
throw error(NA_IN_PROB_VECTOR);
} }
if (prob < 0.0) { RRNG.putRNGState();
throw error(NEGATIVE_PROBABILITY);
} // take names from probVec (if any) as row names in the result
if (prob > 0.0) { RIntVector resultVec = RDataFactory.createIntVector(result, true, new int[]{probsAccess.getLength(probsIter), n});
npos++; if (hasAttributesProfile.profile(probs.getAttributes() != null)) {
sum += prob; Object probsNames = getNamesNode.execute(probs.getAttributes());
updateSharedAttributeNode.execute(probs, probsNames);
Object[] dimnamesData = new Object[]{probsNames, RNull.instance};
setDimNamesNode.execute(resultVec.getAttributes(), RDataFactory.createList(dimnamesData));
} }
} return resultVec;
if (npos == 0) {
throw error(NO_POSITIVE_PROBABILITIES);
}
for (int i = 0; i < pLength; i++) {
double prob = pAccess.getDataAt(i);
pSetter.setDataAt(p, pAccess.getStore(), i, prob / sum);
} }
} }
} }
...@@ -46,8 +46,9 @@ import com.oracle.truffle.r.runtime.data.RDouble; ...@@ -46,8 +46,9 @@ import com.oracle.truffle.r.runtime.data.RDouble;
import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.RDoubleVector;
import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.nodes.ReadAccessor; import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorReadAccess; import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator;
import com.oracle.truffle.r.runtime.nmath.MathFunctions.Function2_1; import com.oracle.truffle.r.runtime.nmath.MathFunctions.Function2_1;
import com.oracle.truffle.r.runtime.nmath.MathFunctions.Function2_2; import com.oracle.truffle.r.runtime.nmath.MathFunctions.Function2_2;
import com.oracle.truffle.r.runtime.nmath.MathFunctions.Function3_1; import com.oracle.truffle.r.runtime.nmath.MathFunctions.Function3_1;
...@@ -374,10 +375,11 @@ public final class StatsFunctionsNodes { ...@@ -374,10 +375,11 @@ public final class StatsFunctionsNodes {
casts.arg(6).asDoubleVector().findFirst(); casts.arg(6).asDoubleVector().findFirst();
} }
@Specialization @Specialization(guards = {"xAccess.supports(x)", "yAccess.supports(y)", "vAccess.supports(v)"})
protected RDoubleVector approx(RAbstractDoubleVector x, RAbstractDoubleVector y, RAbstractDoubleVector v, int method, double yl, double yr, double f, protected RDoubleVector approx(RAbstractDoubleVector x, RAbstractDoubleVector y, RAbstractDoubleVector v, int method, double yl, double yr, double f,
@Cached("create()") VectorReadAccess.Double xAccess, @Cached("x.access()") VectorAccess xAccess,
@Cached("create()") VectorReadAccess.Double yAccess) { @Cached("y.access()") VectorAccess yAccess,
@Cached("v.access()") VectorAccess vAccess) {
int nx = x.getLength(); int nx = x.getLength();
int nout = v.getLength(); int nout = v.getLength();
double[] yout = new double[nout]; double[] yout = new double[nout];
...@@ -390,14 +392,21 @@ public final class StatsFunctionsNodes { ...@@ -390,14 +392,21 @@ public final class StatsFunctionsNodes {
apprMeth.yhigh = yr; apprMeth.yhigh = yr;
naCheck.enable(true); naCheck.enable(true);
ReadAccessor.Double xAccessor = new ReadAccessor.Double(x, xAccess); try (RandomIterator xIter = xAccess.randomAccess(x); RandomIterator yIter = yAccess.randomAccess(y); SequentialIterator vIter = vAccess.access(v)) {
ReadAccessor.Double yAccessor = new ReadAccessor.Double(y, yAccess); int i = 0;
for (int i = 0; i < nout; i++) { while (vAccess.next(vIter)) {
double xouti = v.getDataAt(i); double xouti = vAccess.getDouble(vIter);
yout[i] = RRuntime.isNAorNaN(xouti) ? xouti : approx1(xouti, xAccessor, yAccessor, nx, apprMeth); yout[i] = RRuntime.isNAorNaN(xouti) ? xouti : approx1(xouti, xAccess, xIter, yAccess, yIter, nx, apprMeth);
naCheck.check(yout[i]); naCheck.check(yout[i]);
i++;
}
return RDataFactory.createDoubleVector(yout, naCheck.neverSeenNA());
} }
return RDataFactory.createDoubleVector(yout, naCheck.neverSeenNA()); }
@Specialization(replaces = "approx")
protected RDoubleVector approxGeneric(RAbstractDoubleVector x, RAbstractDoubleVector y, RAbstractDoubleVector v, int method, double yl, double yr, double f) {
return approx(x, y, v, method, yl, yr, f, x.slowPathAccess(), y.slowPathAccess(), v.slowPathAccess());
} }
private static class ApprMeth { private static class ApprMeth {
...@@ -408,32 +417,29 @@ public final class StatsFunctionsNodes { ...@@ -408,32 +417,29 @@ public final class StatsFunctionsNodes {
int kind; int kind;
} }
private static double approx1(double v, ReadAccessor.Double x, ReadAccessor.Double y, int n, private static double approx1(double v, VectorAccess xAccess, RandomIterator xIter, VectorAccess yAccess, RandomIterator yIter, int n,
ApprMeth apprMeth) { ApprMeth apprMeth) {
/* Approximate y(v), given (x,y)[i], i = 0,..,n-1 */ /* Approximate y(v), given (x,y)[i], i = 0,..,n-1 */
int i;
int j;
int ij;
if (n == 0) { if (n == 0) {
return RRuntime.DOUBLE_NA; return RRuntime.DOUBLE_NA;
} }
i = 0; int i = 0;
j = n - 1; int j = n - 1;
/* handle out-of-domain points */ /* handle out-of-domain points */
if (v < x.getDataAt(i)) { if (v < xAccess.getDouble(xIter, i)) {
return apprMeth.ylow; return apprMeth.ylow;
} }
if (v > x.getDataAt(j)) { if (v > xAccess.getDouble(xIter, j)) {
return apprMeth.yhigh; return apprMeth.yhigh;
} }
/* find the correct interval by bisection */ /* find the correct interval by bisection */
while (i < j - 1) { /* x.getDataAt(i) <= v <= x.getDataAt(j) */ while (i < j - 1) { /* x.getDataAt(i) <= v <= x.getDataAt(j) */
ij = (i + j) / 2; /* i+1 <= ij <= j-1 */ int ij = (i + j) / 2;
if (v < x.getDataAt(ij)) { /* i+1 <= ij <= j-1 */
if (v < xAccess.getDouble(xIter, ij)) {
j = ij; j = ij;
} else { } else {
i = ij; i = ij;
...@@ -444,18 +450,22 @@ public final class StatsFunctionsNodes { ...@@ -444,18 +450,22 @@ public final class StatsFunctionsNodes {
/* interpolation */ /* interpolation */
if (v == x.getDataAt(j)) { double xJ = xAccess.getDouble(xIter, j);
return y.getDataAt(j); double yJ = yAccess.getDouble(yIter, j);
if (v == xJ) {
return yJ;
} }
if (v == x.getDataAt(i)) { double xI = xAccess.getDouble(xIter, i);
return y.getDataAt(i); double yI = yAccess.getDouble(yIter, i);
if (v == xI) {
return yI;
} }
/* impossible: if(x.getDataAt(j) == x.getDataAt(i)) return y.getDataAt(i); */ /* impossible: if(x.getDataAt(j) == x.getDataAt(i)) return y.getDataAt(i); */
if (apprMeth.kind == 1) { /* linear */ if (apprMeth.kind == 1) { /* linear */
return y.getDataAt(i) + (y.getDataAt(j) - y.getDataAt(i)) * ((v - x.getDataAt(i)) / (x.getDataAt(j) - x.getDataAt(i))); return yI + (yJ - yI) * ((v - xI) / (xJ - xI));
} else { /* 2 : constant */ } else { /* 2 : constant */
return (apprMeth.f1 != 0.0 ? y.getDataAt(i) * apprMeth.f1 : 0.0) + (apprMeth.f2 != 0.0 ? y.getDataAt(j) * apprMeth.f2 : 0.0); return (apprMeth.f1 != 0.0 ? yI * apprMeth.f1 : 0.0) + (apprMeth.f2 != 0.0 ? yJ * apprMeth.f2 : 0.0);
} }
}/* approx1() */ }/* approx1() */
......
...@@ -324,43 +324,43 @@ public class CallAndExternalFunctions { ...@@ -324,43 +324,43 @@ public class CallAndExternalFunctions {
case "qnorm": case "qnorm":
return StatsFunctionsNodes.Function3_2Node.create(new Qnorm()); return StatsFunctionsNodes.Function3_2Node.create(new Qnorm());
case "rnorm": case "rnorm":
return RandFunction2Node.createDouble(new Rnorm()); return RandFunction2Node.createDouble(Rnorm::new);
case "runif": case "runif":
return RandFunction2Node.createDouble(new Runif()); return RandFunction2Node.createDouble(Runif::new);
case "rbeta": case "rbeta":
return RandFunction2Node.createDouble(new RBeta()); return RandFunction2Node.createDouble(RBeta::new);
case "rgamma": case "rgamma":
return RandFunction2Node.createDouble(new RGamma()); return RandFunction2Node.createDouble(RGamma::new);
case "rcauchy": case "rcauchy":
return RandFunction2Node.createDouble(new RCauchy()); return RandFunction2Node.createDouble(RCauchy::new);
case "rf": case "rf":
return RandFunction2Node.createDouble(new Rf()); return RandFunction2Node.createDouble(Rf::new);
case "rlogis": case "rlogis":
return RandFunction2Node.createDouble(new RLogis()); return RandFunction2Node.createDouble(RLogis::new);
case "rweibull": case "rweibull":
return RandFunction2Node.createDouble(new RWeibull()); return RandFunction2Node.createDouble(RWeibull::new);
case "rnchisq": case "rnchisq":
return RandFunction2Node.createDouble(new RNchisq()); return RandFunction2Node.createDouble(RNchisq::new);
case "rnbinom_mu": case "rnbinom_mu":
return RandFunction2Node.createDouble(new RNBinomMu()); return RandFunction2Node.createDouble(RNBinomMu::new);
case "rwilcox": case "rwilcox":
return RandFunction2Node.createInt(new RWilcox()); return RandFunction2Node.createInt(RWilcox::new);
case "rchisq": case "rchisq":
return RandFunction1Node.createDouble(new RChisq()); return RandFunction1Node.createDouble(RChisq::new);
case "rexp": case "rexp":
return RandFunction1Node.createDouble(new RExp()); return RandFunction1Node.createDouble(RExp::new);
case "rgeom": case "rgeom":
return RandFunction1Node.createInt(new RGeom()); return RandFunction1Node.createInt(RGeom::new);
case "rpois": case "rpois":
return RandFunction1Node.createInt(new RPois()); return RandFunction1Node.createInt(RPois::new);
case "rnbinom": case "rnbinom":
return RandFunction2Node.createInt(new RNBinomFunc()); return RandFunction2Node.createInt(RNBinomFunc::new);
case "rt": case "rt":
return RandFunction1Node.createDouble(new Rt()); return RandFunction1Node.createDouble(Rt::new);
case "rsignrank": case "rsignrank":
return RandFunction1Node.createInt(new RSignrank()); return RandFunction1Node.createInt(RSignrank::new);
case "rhyper": case "rhyper":
return RandFunction3Node.createInt(new RHyper()); return RandFunction3Node.createInt(RHyper::new);
case "phyper": case "phyper":
return StatsFunctionsNodes.Function4_2Node.create(new PHyper()); return StatsFunctionsNodes.Function4_2Node.create(new PHyper());
case "dhyper": case "dhyper":
...@@ -400,7 +400,7 @@ public class CallAndExternalFunctions { ...@@ -400,7 +400,7 @@ public class CallAndExternalFunctions {
case "dweibull": case "dweibull":
return StatsFunctionsNodes.Function3_1Node.create(new DWeibull()); return StatsFunctionsNodes.Function3_1Node.create(new DWeibull());
case "rbinom": case "rbinom":
return RandFunction2Node.createInt(new Rbinom()); return RandFunction2Node.createInt(Rbinom::new);
case "pbinom": case "pbinom":
return StatsFunctionsNodes.Function3_2Node.create(new Pbinom()); return StatsFunctionsNodes.Function3_2Node.create(new Pbinom());
case "pbeta": case "pbeta":
...@@ -458,7 +458,7 @@ public class CallAndExternalFunctions { ...@@ -458,7 +458,7 @@ public class CallAndExternalFunctions {
case "dt": case "dt":
return StatsFunctionsNodes.Function2_1Node.create(new Dt()); return StatsFunctionsNodes.Function2_1Node.create(new Dt());
case "rlnorm": case "rlnorm":
return RandFunction2Node.createDouble(new LogNormal.RLNorm()); return RandFunction2Node.createDouble(LogNormal.RLNorm::new);
case "dlnorm": case "dlnorm":
return StatsFunctionsNodes.Function3_1Node.create(new DLNorm()); return StatsFunctionsNodes.Function3_1Node.create(new DLNorm());
case "qlnorm": case "qlnorm":
......
...@@ -47,11 +47,11 @@ public class RandomFunctions { ...@@ -47,11 +47,11 @@ public class RandomFunctions {
} }
} }
public abstract static class RandFunction1_Double extends RandFunction2_Double { public abstract static class RandFunction1_Double extends RandFunction3_Double {
public abstract double execute(double a, RandomNumberProvider rand); public abstract double execute(double a, RandomNumberProvider rand);
@Override @Override
public final double execute(double a, double b, RandomNumberProvider rand) { public final double execute(double a, double b, double c, RandomNumberProvider rand) {
return execute(a, rand); return execute(a, rand);
} }
} }
......
...@@ -18,7 +18,8 @@ import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; ...@@ -18,7 +18,8 @@ import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.RError.Message;
import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.data.nodes.ReadAccessor; import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandomNumberProvider; import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandomNumberProvider;
public final class RMultinom { public final class RMultinom {
...@@ -32,7 +33,7 @@ public final class RMultinom { ...@@ -32,7 +33,7 @@ public final class RMultinom {
* prob[j]) , sum_j rN[j] == n, sum_j prob[j] == 1. * prob[j]) , sum_j rN[j] == n, sum_j prob[j] == 1.
*/ */
@TruffleBoundary @TruffleBoundary
public static boolean rmultinom(int nIn, ReadAccessor.Double prob, int maxK, int[] rN, int rnStartIdx, RandomNumberProvider rand, Rbinom rbinom) { public static boolean rmultinom(int nIn, SequentialIterator probsIter, VectorAccess probsAccess, double sum, int[] rN, int rnStartIdx, RandomNumberProvider rand, Rbinom rbinom) {
/* /*
* This calculation is sensitive to exact values, so we try to ensure that the calculations * This calculation is sensitive to exact values, so we try to ensure that the calculations
* are as accurate as possible so different platforms are more likely to give the same * are as accurate as possible so different platforms are more likely to give the same
...@@ -40,6 +41,7 @@ public final class RMultinom { ...@@ -40,6 +41,7 @@ public final class RMultinom {
*/ */
int n = nIn; int n = nIn;
int maxK = probsAccess.getLength(probsIter);
if (RRuntime.isNA(maxK) || maxK < 1 || RRuntime.isNA(n) || n < 0) { if (RRuntime.isNA(maxK) || maxK < 1 || RRuntime.isNA(n) || n < 0) {
if (rN.length > rnStartIdx) { if (rN.length > rnStartIdx) {
rN[rnStartIdx] = RRuntime.INT_NA; rN[rnStartIdx] = RRuntime.INT_NA;
...@@ -52,8 +54,9 @@ public final class RMultinom { ...@@ -52,8 +54,9 @@ public final class RMultinom {
* shorter and drop that check ! * shorter and drop that check !
*/ */
/* LDOUBLE */double pTot = 0.; /* LDOUBLE */double pTot = 0.;
for (int k = 0; k < maxK; k++) { probsAccess.reset(probsIter);
double pp = prob.getDataAt(k); for (int k = 0; probsAccess.next(probsIter); k++) {
double pp = probsAccess.getDouble(probsIter) / sum;
if (!Double.isFinite(pp) || pp < 0. || pp > 1.) { if (!Double.isFinite(pp) || pp < 0. || pp > 1.) {
rN[rnStartIdx + k] = RRuntime.INT_NA; rN[rnStartIdx + k] = RRuntime.INT_NA;
return false; return false;
...@@ -74,9 +77,10 @@ public final class RMultinom { ...@@ -74,9 +77,10 @@ public final class RMultinom {
} }
/* Generate the first K-1 obs. via binomials */ /* Generate the first K-1 obs. via binomials */
for (int k = 0; k < maxK - 1; k++) { probsAccess.reset(probsIter);
for (int k = 0; probsAccess.next(probsIter) && k < maxK - 1; k++) {
/* (p_tot, n) are for "remaining binomial" */ /* (p_tot, n) are for "remaining binomial" */
/* LDOUBLE */double probK = prob.getDataAt(k); /* LDOUBLE */double probK = probsAccess.getDouble(probsIter) / sum;
if (probK != 0.) { if (probK != 0.) {
double pp = probK / pTot; double pp = probK / pTot;
// System.out.printf("[%d] %.17f\n", k + 1, pp); // System.out.printf("[%d] %.17f\n", k + 1, pp);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment