From ed3e124590f1c414b1d54799f595ede542da0b57 Mon Sep 17 00:00:00 2001 From: Lukas Stadler <lukas.stadler@oracle.com> Date: Wed, 15 Nov 2017 17:02:36 +0100 Subject: [PATCH] convert RNG and stats functions to VectorAccess --- .../r/library/stats/RMultinomNode.java | 126 ++++---- .../r/library/stats/RandFunctionsNodes.java | 273 ++++++++++-------- .../r/library/stats/StatsFunctionsNodes.java | 68 +++-- .../foreign/CallAndExternalFunctions.java | 42 +-- .../r/runtime/nmath/RandomFunctions.java | 4 +- .../r/runtime/nmath/distr/RMultinom.java | 16 +- 6 files changed, 286 insertions(+), 243 deletions(-) diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RMultinomNode.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RMultinomNode.java index 56fa2752f5..0fcd70396e 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RMultinomNode.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RMultinomNode.java @@ -27,18 +27,15 @@ import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.r.nodes.attributes.GetFixedAttributeNode; import com.oracle.truffle.r.nodes.attributes.SetFixedAttributeNode; import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode; -import com.oracle.truffle.r.nodes.function.opt.ReuseNonSharedNode; import com.oracle.truffle.r.nodes.function.opt.UpdateShareableChildValueNode; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.data.RDataFactory; -import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.RIntVector; import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; -import com.oracle.truffle.r.runtime.data.nodes.ReadAccessor; -import com.oracle.truffle.r.runtime.data.nodes.SetDataAt; -import com.oracle.truffle.r.runtime.data.nodes.VectorReadAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandomNumberProvider; import com.oracle.truffle.r.runtime.nmath.distr.RMultinom; import com.oracle.truffle.r.runtime.nmath.distr.Rbinom; @@ -49,8 +46,15 @@ import com.oracle.truffle.r.runtime.rng.RRNG; * Implements the vectorization of {@link RMultinom}. */ public abstract class RMultinomNode extends RExternalBuiltinNode.Arg3 { + private final Rbinom rbinom = new Rbinom(); + private final ValueProfile randGeneratorClassProfile = ValueProfile.createClassProfile(); + private final ConditionProfile hasAttributesProfile = ConditionProfile.createBinaryProfile(); + @Child private UpdateShareableChildValueNode updateSharedAttributeNode = UpdateShareableChildValueNode.create(); + @Child private GetFixedAttributeNode getNamesNode = GetFixedAttributeNode.createNames(); + @Child private SetFixedAttributeNode setDimNamesNode = SetFixedAttributeNode.createDimNames(); + public static RMultinomNode create() { return RMultinomNodeGen.create(); } @@ -68,63 +72,69 @@ public abstract class RMultinomNode extends RExternalBuiltinNode.Arg3 { } @Specialization - protected RIntVector doMultinom(int n, int size, RAbstractDoubleVector probsVec, - @Cached("create()") VectorReadAccess.Double probsAccess, - @Cached("create()") SetDataAt.Double probsSetter, - @Cached("create()") ReuseNonSharedNode reuseNonSharedNode, - @Cached("createClassProfile()") ValueProfile randGeneratorClassProfile, - @Cached("createBinaryProfile()") ConditionProfile hasAttributesProfile, - @Cached("create()") UpdateShareableChildValueNode updateSharedAttributeNode, - @Cached("createNames()") GetFixedAttributeNode getNamesNode, - @Cached("createDimNames()") SetFixedAttributeNode setDimNamesNode) { - RDoubleVector nonSharedProbs = ((RAbstractDoubleVector) reuseNonSharedNode.execute(probsVec)).materialize(); - ReadAccessor.Double probs = new ReadAccessor.Double(nonSharedProbs, probsAccess); - fixupProb(nonSharedProbs, probs, probsSetter); - - RRNG.getRNGState(); - RandomNumberProvider rand = new RandomNumberProvider(randGeneratorClassProfile.profile(RRNG.currentGenerator()), RRNG.currentNormKind()); - int k = nonSharedProbs.getLength(); - int[] result = new int[k * n]; - boolean isComplete = true; - for (int i = 0, ik = 0; i < n; i++, ik += k) { - isComplete &= RMultinom.rmultinom(size, probs, k, result, ik, rand, rbinom); - } - RRNG.putRNGState(); + protected RIntVector doMultinom(int n, int size, RAbstractDoubleVector probs, + @Cached("probs.access()") VectorAccess probsAccess) { + try (SequentialIterator probsIter = probsAccess.access(probs)) { + double sum = 0.0; + while (probsAccess.next(probsIter)) { + double prob = probsAccess.getDouble(probsIter); + if (!Double.isFinite(prob)) { + throw error(NA_IN_PROB_VECTOR); + } + if (prob < 0.0) { + throw error(NEGATIVE_PROBABILITY); + } + sum += prob; + } + if (sum == 0) { + throw error(NO_POSITIVE_PROBABILITIES); + } - // take names from probVec (if any) as row names in the result - RIntVector resultVec = RDataFactory.createIntVector(result, isComplete, new int[]{k, n}); - if (hasAttributesProfile.profile(probsVec.getAttributes() != null)) { - Object probsNames = getNamesNode.execute(probsVec.getAttributes()); - updateSharedAttributeNode.execute(probsVec, probsNames); - Object[] dimnamesData = new Object[]{probsNames, RNull.instance}; - setDimNamesNode.execute(resultVec.getAttributes(), RDataFactory.createList(dimnamesData)); - } - return resultVec; - } + RRNG.getRNGState(); + RandomNumberProvider rand = new RandomNumberProvider(randGeneratorClassProfile.profile(RRNG.currentGenerator()), RRNG.currentNormKind()); + int[] result = new int[probsAccess.getLength(probsIter) * n]; + if (size > 0) { + for (int i = 0, ik = 0; i < n; i++, ik += probsAccess.getLength(probsIter)) { + double currentSum = sum; + int currentSize = size; + /* Generate the first K-1 obs. via binomials */ + probsAccess.reset(probsIter); + for (int k = 0; probsAccess.next(probsIter) && k < probsAccess.getLength(probsIter) - 1; k++) { + /* (p_tot, n) are for "remaining binomial" */ + /* LDOUBLE */double probK = probsAccess.getDouble(probsIter); + if (probK != 0.) { + double pp = probK / currentSum; + int value = (pp < 1.) ? (int) rbinom.execute(currentSize, pp, rand) : currentSize; + /* + * >= 1; > 1 happens because of rounding + */ + result[ik + k] = value; + currentSize -= value; + } else { + result[ik + k] = 0; + } + if (n <= 0) { + /* we have all */ + break; + } + /* i.e. = sum(prob[(k+1):K]) */ + currentSum -= probK; + } - private void fixupProb(RDoubleVector p, ReadAccessor.Double pAccess, SetDataAt.Double pSetter) { - double sum = 0.0; - int npos = 0; - int pLength = p.getLength(); - for (int i = 0; i < pLength; i++) { - double prob = pAccess.getDataAt(i); - if (!Double.isFinite(prob)) { - throw error(NA_IN_PROB_VECTOR); + result[ik + probsAccess.getLength(probsIter) - 1] = currentSize; + } } - if (prob < 0.0) { - throw error(NEGATIVE_PROBABILITY); - } - if (prob > 0.0) { - npos++; - sum += prob; + RRNG.putRNGState(); + + // take names from probVec (if any) as row names in the result + RIntVector resultVec = RDataFactory.createIntVector(result, true, new int[]{probsAccess.getLength(probsIter), n}); + if (hasAttributesProfile.profile(probs.getAttributes() != null)) { + Object probsNames = getNamesNode.execute(probs.getAttributes()); + updateSharedAttributeNode.execute(probs, probsNames); + Object[] dimnamesData = new Object[]{probsNames, RNull.instance}; + setDimNamesNode.execute(resultVec.getAttributes(), RDataFactory.createList(dimnamesData)); } - } - if (npos == 0) { - throw error(NO_POSITIVE_PROBABILITIES); - } - for (int i = 0; i < pLength; i++) { - double prob = pAccess.getDataAt(i); - pSetter.setDataAt(p, pAccess.getStore(), i, prob / sum); + return resultVec; } } } diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandFunctionsNodes.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandFunctionsNodes.java index 457b35543d..e400463057 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandFunctionsNodes.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/RandFunctionsNodes.java @@ -16,10 +16,12 @@ package com.oracle.truffle.r.library.stats; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.abstractVectorValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.missingValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue; -import static com.oracle.truffle.r.runtime.RError.Message.INVALID_UNNAMED_ARGUMENTS; import static com.oracle.truffle.r.runtime.RError.SHOW_CALLER; +import static com.oracle.truffle.r.runtime.RError.Message.INVALID_UNNAMED_ARGUMENTS; import java.util.Arrays; +import java.util.function.Function; +import java.util.function.Supplier; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; @@ -30,19 +32,22 @@ import com.oracle.truffle.r.library.stats.RandFunctionsNodesFactory.ConvertToLen import com.oracle.truffle.r.library.stats.RandFunctionsNodesFactory.RandFunction1NodeGen; import com.oracle.truffle.r.library.stats.RandFunctionsNodesFactory.RandFunction2NodeGen; import com.oracle.truffle.r.library.stats.RandFunctionsNodesFactory.RandFunction3NodeGen; +import com.oracle.truffle.r.library.stats.RandFunctionsNodesFactory.RandFunctionDoubleExecutorNodeGen; +import com.oracle.truffle.r.library.stats.RandFunctionsNodesFactory.RandFunctionExecutorBaseNodeGen; +import com.oracle.truffle.r.library.stats.RandFunctionsNodesFactory.RandFunctionIntExecutorNodeGen; import com.oracle.truffle.r.nodes.builtin.NodeWithArgumentCasts.Casts; import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode; import com.oracle.truffle.r.nodes.profile.VectorLengthProfile; import com.oracle.truffle.r.nodes.unary.CastIntegerNode; import com.oracle.truffle.r.runtime.RError; -import com.oracle.truffle.r.runtime.RInternalError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDouble; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; -import com.oracle.truffle.r.runtime.data.nodes.VectorIterator; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction1_Double; import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction2_Double; import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandFunction3_Double; @@ -100,53 +105,65 @@ public final class RandFunctionsNodes { * {@link RandFunction3_Double}. */ protected abstract static class RandFunctionExecutorBase extends RBaseNode { - static final class RandGenerationNodeData { - final BranchProfile nanResult = BranchProfile.create(); - final BranchProfile nan = BranchProfile.create(); - final VectorLengthProfile resultVectorLengthProfile = VectorLengthProfile.create(); - final LoopConditionProfile loopConditionProfile = LoopConditionProfile.createCountingProfile(); - - public static RandGenerationNodeData create() { - return new RandGenerationNodeData(); - } + + protected final Function<Supplier<? extends RandFunction3_Double>, RandFunctionIterator> iteratorFactory; + protected final Supplier<? extends RandFunction3_Double> functionFactory; + + protected RandFunctionExecutorBase(Function<Supplier<? extends RandFunction3_Double>, RandFunctionIterator> iteratorFactory, Supplier<? extends RandFunction3_Double> functionFactory) { + this.iteratorFactory = iteratorFactory; + this.functionFactory = functionFactory; } + public abstract RAbstractVector execute(RAbstractVector length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandomNumberProvider rand); + + @Child private ConvertToLength convertToLength = ConvertToLengthNodeGen.create(); + private final VectorLengthProfile resultVectorLengthProfile = VectorLengthProfile.create(); + @Override - protected RBaseNode getErrorContext() { + protected final RBaseNode getErrorContext() { return RError.SHOW_CALLER; } - public abstract Object execute(RAbstractVector length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandomNumberProvider rand); - - @Child private ConvertToLength convertToLength = ConvertToLengthNodeGen.create(); + protected final RandFunctionIterator createIterator() { + return iteratorFactory.apply(functionFactory); + } @Specialization(guards = {"randCached.isSame(rand)"}) - protected final Object evaluateWithCached(RAbstractVector lengthVec, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, + protected final RAbstractVector evaluateWithCached(RAbstractVector lengthVec, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, @SuppressWarnings("unused") RandomNumberProvider rand, @Cached("rand") RandomNumberProvider randCached, - @Cached("create()") RandGenerationNodeData nodeData) { - return evaluateWrapper(lengthVec, a, b, c, randCached, nodeData); + @Cached("createIterator()") RandFunctionIterator iterator) { + int length = resultVectorLengthProfile.profile(convertToLength.execute(lengthVec)); + RBaseNode.reportWork(this, length); + return iterator.execute(length, a, b, c, randCached); } @Specialization(replaces = "evaluateWithCached") - protected final Object evaluateFallback(RAbstractVector lengthVec, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandomNumberProvider rand, - @Cached("create()") RandGenerationNodeData nodeData) { - return evaluateWrapper(lengthVec, a, b, c, rand, nodeData); + protected final RAbstractVector evaluateFallback(RAbstractVector lengthVec, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandomNumberProvider rand, + @Cached("createIterator()") RandFunctionIterator iterator) { + int length = resultVectorLengthProfile.profile(convertToLength.execute(lengthVec)); + RBaseNode.reportWork(this, length); + return iterator.execute(length, a, b, c, rand); } + } - private Object evaluateWrapper(RAbstractVector lengthVec, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandomNumberProvider rand, - RandGenerationNodeData nodeData) { - int length = nodeData.resultVectorLengthProfile.profile(convertToLength.execute(lengthVec)); - RBaseNode.reportWork(this, length); - return evaluate(length, a, b, c, nodeData, rand); + protected abstract static class RandFunctionIterator extends RBaseNode { + + protected final Supplier<? extends RandFunction3_Double> functionFactory; + protected final BranchProfile nanResult = BranchProfile.create(); + protected final BranchProfile nan = BranchProfile.create(); + protected final LoopConditionProfile loopConditionProfile = LoopConditionProfile.createCountingProfile(); + + protected RandFunctionIterator(Supplier<? extends RandFunction3_Double> functionFactory) { + this.functionFactory = functionFactory; } - @SuppressWarnings("unused") - Object evaluate(int length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandGenerationNodeData nodeData, RandomNumberProvider randProvider) { - // DSL generates code for this class too, with abstract method it would not compile - throw RInternalError.shouldNotReachHere("must be overridden"); + protected final RandFunction3_Double createFunction() { + return functionFactory.get(); } + public abstract RAbstractVector execute(int length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandomNumberProvider rand); + static void putRNGState() { // Note: we call putRNGState only if we actually changed the state, i.e. called random // number generation. We do not need to getRNGState() because the parent wrapper node @@ -159,105 +176,107 @@ public final class RandFunctionsNodes { } } - protected abstract static class RandFunctionIntExecutorNode extends RandFunctionExecutorBase { - @Child private RandFunction3_Double function; - @Child private VectorIterator.Double aIterator = VectorIterator.Double.createWrapAround(); - @Child private VectorIterator.Double bIterator = VectorIterator.Double.createWrapAround(); - @Child private VectorIterator.Double cIterator = VectorIterator.Double.createWrapAround(); + protected abstract static class RandFunctionIntExecutorNode extends RandFunctionIterator { - protected RandFunctionIntExecutorNode(RandFunction3_Double function) { - this.function = function; + protected RandFunctionIntExecutorNode(Supplier<? extends RandFunction3_Double> functionFactory) { + super(functionFactory); } - @Override - protected RAbstractIntVector evaluate(int length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandGenerationNodeData nodeData, - RandomNumberProvider randProvider) { - int aLength = a.getLength(); - int bLength = b.getLength(); - int cLength = c.getLength(); - if (aLength == 0 || bLength == 0 || cLength == 0) { - nodeData.nanResult.enter(); - showNAWarning(); - int[] nansResult = new int[length]; - Arrays.fill(nansResult, RRuntime.INT_NA); - return RDataFactory.createIntVector(nansResult, false); - } + @Specialization(guards = {"aAccess.supports(a)", "bAccess.supports(b)", "cAccess.supports(c)"}) + protected RAbstractIntVector cached(int length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandomNumberProvider randProvider, + @Cached("createFunction()") RandFunction3_Double function, + @Cached("a.access()") VectorAccess aAccess, + @Cached("b.access()") VectorAccess bAccess, + @Cached("c.access()") VectorAccess cAccess) { + try (SequentialIterator aIter = aAccess.access(a); SequentialIterator bIter = bAccess.access(b); SequentialIterator cIter = cAccess.access(c)) { + if (aAccess.getLength(aIter) == 0 || bAccess.getLength(bIter) == 0 || cAccess.getLength(cIter) == 0) { + nanResult.enter(); + showNAWarning(); + int[] nansResult = new int[length]; + Arrays.fill(nansResult, RRuntime.INT_NA); + return RDataFactory.createIntVector(nansResult, false); + } - Object aIt = aIterator.init(a); - Object bIt = bIterator.init(b); - Object cIt = cIterator.init(c); - boolean nans = false; - int[] result = new int[length]; - nodeData.loopConditionProfile.profileCounted(length); - for (int i = 0; nodeData.loopConditionProfile.inject(i < length); i++) { - double aValue = aIterator.next(a, aIt); - double bValue = bIterator.next(b, bIt); - double cValue = cIterator.next(c, cIt); - double value = function.execute(aValue, bValue, cValue, randProvider); - if (Double.isNaN(value) || value <= Integer.MIN_VALUE || value > Integer.MAX_VALUE) { - nodeData.nan.enter(); - nans = true; - result[i] = RRuntime.INT_NA; - } else { - result[i] = (int) value; + boolean nans = false; + int[] result = new int[length]; + loopConditionProfile.profileCounted(length); + for (int i = 0; loopConditionProfile.inject(i < length); i++) { + aAccess.nextWithWrap(aIter); + bAccess.nextWithWrap(bIter); + cAccess.nextWithWrap(cIter); + double value = function.execute(aAccess.getDouble(aIter), bAccess.getDouble(bIter), cAccess.getDouble(cIter), randProvider); + if (Double.isNaN(value) || value <= Integer.MIN_VALUE || value > Integer.MAX_VALUE) { + nan.enter(); + nans = true; + result[i] = RRuntime.INT_NA; + } else { + result[i] = (int) value; + } } + putRNGState(); + if (nans) { + showNAWarning(); + } + return RDataFactory.createIntVector(result, !nans); } - putRNGState(); - if (nans) { - showNAWarning(); - } - return RDataFactory.createIntVector(result, !nans); + } + + @Specialization(replaces = "cached") + protected RAbstractIntVector generic(int length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandomNumberProvider randProvider, + @Cached("createFunction()") RandFunction3_Double function) { + return cached(length, a, b, c, randProvider, function, a.slowPathAccess(), b.slowPathAccess(), c.slowPathAccess()); } } - protected abstract static class RandFunctionDoubleExecutorNode extends RandFunctionExecutorBase { - @Child private RandFunction3_Double function; - @Child private VectorIterator.Double aIterator = VectorIterator.Double.createWrapAround(); - @Child private VectorIterator.Double bIterator = VectorIterator.Double.createWrapAround(); - @Child private VectorIterator.Double cIterator = VectorIterator.Double.createWrapAround(); + protected abstract static class RandFunctionDoubleExecutorNode extends RandFunctionIterator { - protected RandFunctionDoubleExecutorNode(RandFunction3_Double function) { - this.function = function; + protected RandFunctionDoubleExecutorNode(Supplier<? extends RandFunction3_Double> functionFactory) { + super(functionFactory); } - @Override - protected RAbstractDoubleVector evaluate(int length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandGenerationNodeData nodeData, - RandomNumberProvider randProvider) { - int aLength = a.getLength(); - int bLength = b.getLength(); - int cLength = c.getLength(); - if (aLength == 0 || bLength == 0 || cLength == 0) { - nodeData.nanResult.enter(); - showNAWarning(); - double[] nansResult = new double[length]; - Arrays.fill(nansResult, RRuntime.DOUBLE_NA); - return RDataFactory.createDoubleVector(nansResult, false); - } + @Specialization(guards = {"aAccess.supports(a)", "bAccess.supports(b)", "cAccess.supports(c)"}) + protected RAbstractDoubleVector cached(int length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandomNumberProvider randProvider, + @Cached("createFunction()") RandFunction3_Double function, + @Cached("a.access()") VectorAccess aAccess, + @Cached("b.access()") VectorAccess bAccess, + @Cached("c.access()") VectorAccess cAccess) { + try (SequentialIterator aIter = aAccess.access(a); SequentialIterator bIter = bAccess.access(b); SequentialIterator cIter = cAccess.access(c)) { + if (aAccess.getLength(aIter) == 0 || bAccess.getLength(bIter) == 0 || cAccess.getLength(cIter) == 0) { + nanResult.enter(); + showNAWarning(); + double[] nansResult = new double[length]; + Arrays.fill(nansResult, RRuntime.DOUBLE_NA); + return RDataFactory.createDoubleVector(nansResult, false); + } - Object aIt = aIterator.init(a); - Object bIt = bIterator.init(b); - Object cIt = cIterator.init(c); - boolean nans = false; - double[] result; - result = new double[length]; - nodeData.loopConditionProfile.profileCounted(length); - for (int i = 0; nodeData.loopConditionProfile.inject(i < length); i++) { - double aValue = aIterator.next(a, aIt); - double bValue = bIterator.next(b, bIt); - double cValue = cIterator.next(c, cIt); - double value = function.execute(aValue, bValue, cValue, randProvider); - if (Double.isNaN(value) || RRuntime.isNA(value)) { - nodeData.nan.enter(); - nans = true; + boolean nans = false; + double[] result = new double[length]; + loopConditionProfile.profileCounted(length); + for (int i = 0; loopConditionProfile.inject(i < length); i++) { + aAccess.nextWithWrap(aIter); + bAccess.nextWithWrap(bIter); + cAccess.nextWithWrap(cIter); + double value = function.execute(aAccess.getDouble(aIter), bAccess.getDouble(bIter), cAccess.getDouble(cIter), randProvider); + if (Double.isNaN(value) || RRuntime.isNA(value)) { + nan.enter(); + nans = true; + } + result[i] = value; } - result[i] = value; - } - putRNGState(); - if (nans) { - showNAWarning(); + putRNGState(); + if (nans) { + showNAWarning(); + } + return RDataFactory.createDoubleVector(result, !nans); } - return RDataFactory.createDoubleVector(result, !nans); } + + @Specialization(replaces = "cached") + protected RAbstractDoubleVector generic(int length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, RandomNumberProvider randProvider, + @Cached("createFunction()") RandFunction3_Double function) { + return cached(length, a, b, c, randProvider, function, a.slowPathAccess(), b.slowPathAccess(), c.slowPathAccess()); + } + } public abstract static class RandFunction3Node extends RExternalBuiltinNode.Arg4 { @@ -267,13 +286,13 @@ public final class RandFunctionsNodes { this.inner = inner; } - public static RandFunction3Node createInt(RandFunction3_Double function) { - return RandFunction3NodeGen.create(RandFunctionsNodesFactory.RandFunctionIntExecutorNodeGen.create(function)); + public static RandFunction3Node createInt(Supplier<RandFunction3_Double> function) { + return RandFunction3NodeGen.create(RandFunctionExecutorBaseNodeGen.create(RandFunctionIntExecutorNodeGen::create, function)); } // Note: for completeness of the API - public static RandFunction3Node createDouble(RandFunction3_Double function) { - return RandFunction3NodeGen.create(RandFunctionsNodesFactory.RandFunctionDoubleExecutorNodeGen.create(function)); + public static RandFunction3Node createDouble(Supplier<RandFunction3_Double> function) { + return RandFunction3NodeGen.create(RandFunctionExecutorBaseNodeGen.create(RandFunctionDoubleExecutorNodeGen::create, function)); } static { @@ -285,7 +304,7 @@ public final class RandFunctionsNodes { } @Specialization - protected Object evaluate(RAbstractVector length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c) { + protected RAbstractVector evaluate(RAbstractVector length, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c) { RRNG.getRNGState(); return inner.execute(length, a, b, c, RandomNumberProvider.fromCurrentRNG()); } @@ -298,12 +317,12 @@ public final class RandFunctionsNodes { this.inner = inner; } - public static RandFunction2Node createInt(RandFunction2_Double function) { - return RandFunction2NodeGen.create(RandFunctionsNodesFactory.RandFunctionIntExecutorNodeGen.create(function)); + public static RandFunction2Node createInt(Supplier<RandFunction2_Double> function) { + return RandFunction2NodeGen.create(RandFunctionExecutorBaseNodeGen.create(RandFunctionIntExecutorNodeGen::create, function)); } - public static RandFunction2Node createDouble(RandFunction2_Double function) { - return RandFunction2NodeGen.create(RandFunctionsNodesFactory.RandFunctionDoubleExecutorNodeGen.create(function)); + public static RandFunction2Node createDouble(Supplier<RandFunction2_Double> function) { + return RandFunction2NodeGen.create(RandFunctionExecutorBaseNodeGen.create(RandFunctionDoubleExecutorNodeGen::create, function)); } static { @@ -327,12 +346,12 @@ public final class RandFunctionsNodes { this.inner = inner; } - public static RandFunction1Node createInt(RandFunction1_Double function) { - return RandFunction1NodeGen.create(RandFunctionsNodesFactory.RandFunctionIntExecutorNodeGen.create(function)); + public static RandFunction1Node createInt(Supplier<RandFunction1_Double> function) { + return RandFunction1NodeGen.create(RandFunctionExecutorBaseNodeGen.create(RandFunctionIntExecutorNodeGen::create, function)); } - public static RandFunction1Node createDouble(RandFunction1_Double function) { - return RandFunction1NodeGen.create(RandFunctionsNodesFactory.RandFunctionDoubleExecutorNodeGen.create(function)); + public static RandFunction1Node createDouble(Supplier<RandFunction1_Double> function) { + return RandFunction1NodeGen.create(RandFunctionExecutorBaseNodeGen.create(RandFunctionDoubleExecutorNodeGen::create, function)); } static { diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsFunctionsNodes.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsFunctionsNodes.java index ebba5274ac..51564fd6f8 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsFunctionsNodes.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsFunctionsNodes.java @@ -46,8 +46,9 @@ import com.oracle.truffle.r.runtime.data.RDouble; import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; -import com.oracle.truffle.r.runtime.data.nodes.ReadAccessor; -import com.oracle.truffle.r.runtime.data.nodes.VectorReadAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator; import com.oracle.truffle.r.runtime.nmath.MathFunctions.Function2_1; import com.oracle.truffle.r.runtime.nmath.MathFunctions.Function2_2; import com.oracle.truffle.r.runtime.nmath.MathFunctions.Function3_1; @@ -374,10 +375,11 @@ public final class StatsFunctionsNodes { casts.arg(6).asDoubleVector().findFirst(); } - @Specialization + @Specialization(guards = {"xAccess.supports(x)", "yAccess.supports(y)", "vAccess.supports(v)"}) protected RDoubleVector approx(RAbstractDoubleVector x, RAbstractDoubleVector y, RAbstractDoubleVector v, int method, double yl, double yr, double f, - @Cached("create()") VectorReadAccess.Double xAccess, - @Cached("create()") VectorReadAccess.Double yAccess) { + @Cached("x.access()") VectorAccess xAccess, + @Cached("y.access()") VectorAccess yAccess, + @Cached("v.access()") VectorAccess vAccess) { int nx = x.getLength(); int nout = v.getLength(); double[] yout = new double[nout]; @@ -390,14 +392,21 @@ public final class StatsFunctionsNodes { apprMeth.yhigh = yr; naCheck.enable(true); - ReadAccessor.Double xAccessor = new ReadAccessor.Double(x, xAccess); - ReadAccessor.Double yAccessor = new ReadAccessor.Double(y, yAccess); - for (int i = 0; i < nout; i++) { - double xouti = v.getDataAt(i); - yout[i] = RRuntime.isNAorNaN(xouti) ? xouti : approx1(xouti, xAccessor, yAccessor, nx, apprMeth); - naCheck.check(yout[i]); + try (RandomIterator xIter = xAccess.randomAccess(x); RandomIterator yIter = yAccess.randomAccess(y); SequentialIterator vIter = vAccess.access(v)) { + int i = 0; + while (vAccess.next(vIter)) { + double xouti = vAccess.getDouble(vIter); + yout[i] = RRuntime.isNAorNaN(xouti) ? xouti : approx1(xouti, xAccess, xIter, yAccess, yIter, nx, apprMeth); + naCheck.check(yout[i]); + i++; + } + return RDataFactory.createDoubleVector(yout, naCheck.neverSeenNA()); } - return RDataFactory.createDoubleVector(yout, naCheck.neverSeenNA()); + } + + @Specialization(replaces = "approx") + protected RDoubleVector approxGeneric(RAbstractDoubleVector x, RAbstractDoubleVector y, RAbstractDoubleVector v, int method, double yl, double yr, double f) { + return approx(x, y, v, method, yl, yr, f, x.slowPathAccess(), y.slowPathAccess(), v.slowPathAccess()); } private static class ApprMeth { @@ -408,32 +417,29 @@ public final class StatsFunctionsNodes { int kind; } - private static double approx1(double v, ReadAccessor.Double x, ReadAccessor.Double y, int n, + private static double approx1(double v, VectorAccess xAccess, RandomIterator xIter, VectorAccess yAccess, RandomIterator yIter, int n, ApprMeth apprMeth) { /* Approximate y(v), given (x,y)[i], i = 0,..,n-1 */ - int i; - int j; - int ij; if (n == 0) { return RRuntime.DOUBLE_NA; } - i = 0; - j = n - 1; - + int i = 0; + int j = n - 1; /* handle out-of-domain points */ - if (v < x.getDataAt(i)) { + if (v < xAccess.getDouble(xIter, i)) { return apprMeth.ylow; } - if (v > x.getDataAt(j)) { + if (v > xAccess.getDouble(xIter, j)) { return apprMeth.yhigh; } /* find the correct interval by bisection */ while (i < j - 1) { /* x.getDataAt(i) <= v <= x.getDataAt(j) */ - ij = (i + j) / 2; /* i+1 <= ij <= j-1 */ - if (v < x.getDataAt(ij)) { + int ij = (i + j) / 2; + /* i+1 <= ij <= j-1 */ + if (v < xAccess.getDouble(xIter, ij)) { j = ij; } else { i = ij; @@ -444,18 +450,22 @@ public final class StatsFunctionsNodes { /* interpolation */ - if (v == x.getDataAt(j)) { - return y.getDataAt(j); + double xJ = xAccess.getDouble(xIter, j); + double yJ = yAccess.getDouble(yIter, j); + if (v == xJ) { + return yJ; } - if (v == x.getDataAt(i)) { - return y.getDataAt(i); + double xI = xAccess.getDouble(xIter, i); + double yI = yAccess.getDouble(yIter, i); + if (v == xI) { + return yI; } /* impossible: if(x.getDataAt(j) == x.getDataAt(i)) return y.getDataAt(i); */ if (apprMeth.kind == 1) { /* linear */ - return y.getDataAt(i) + (y.getDataAt(j) - y.getDataAt(i)) * ((v - x.getDataAt(i)) / (x.getDataAt(j) - x.getDataAt(i))); + return yI + (yJ - yI) * ((v - xI) / (xJ - xI)); } else { /* 2 : constant */ - return (apprMeth.f1 != 0.0 ? y.getDataAt(i) * apprMeth.f1 : 0.0) + (apprMeth.f2 != 0.0 ? y.getDataAt(j) * apprMeth.f2 : 0.0); + return (apprMeth.f1 != 0.0 ? yI * apprMeth.f1 : 0.0) + (apprMeth.f2 != 0.0 ? yJ * apprMeth.f2 : 0.0); } }/* approx1() */ diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/CallAndExternalFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/CallAndExternalFunctions.java index fa7bd09346..97bdda9d9a 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/CallAndExternalFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/CallAndExternalFunctions.java @@ -324,43 +324,43 @@ public class CallAndExternalFunctions { case "qnorm": return StatsFunctionsNodes.Function3_2Node.create(new Qnorm()); case "rnorm": - return RandFunction2Node.createDouble(new Rnorm()); + return RandFunction2Node.createDouble(Rnorm::new); case "runif": - return RandFunction2Node.createDouble(new Runif()); + return RandFunction2Node.createDouble(Runif::new); case "rbeta": - return RandFunction2Node.createDouble(new RBeta()); + return RandFunction2Node.createDouble(RBeta::new); case "rgamma": - return RandFunction2Node.createDouble(new RGamma()); + return RandFunction2Node.createDouble(RGamma::new); case "rcauchy": - return RandFunction2Node.createDouble(new RCauchy()); + return RandFunction2Node.createDouble(RCauchy::new); case "rf": - return RandFunction2Node.createDouble(new Rf()); + return RandFunction2Node.createDouble(Rf::new); case "rlogis": - return RandFunction2Node.createDouble(new RLogis()); + return RandFunction2Node.createDouble(RLogis::new); case "rweibull": - return RandFunction2Node.createDouble(new RWeibull()); + return RandFunction2Node.createDouble(RWeibull::new); case "rnchisq": - return RandFunction2Node.createDouble(new RNchisq()); + return RandFunction2Node.createDouble(RNchisq::new); case "rnbinom_mu": - return RandFunction2Node.createDouble(new RNBinomMu()); + return RandFunction2Node.createDouble(RNBinomMu::new); case "rwilcox": - return RandFunction2Node.createInt(new RWilcox()); + return RandFunction2Node.createInt(RWilcox::new); case "rchisq": - return RandFunction1Node.createDouble(new RChisq()); + return RandFunction1Node.createDouble(RChisq::new); case "rexp": - return RandFunction1Node.createDouble(new RExp()); + return RandFunction1Node.createDouble(RExp::new); case "rgeom": - return RandFunction1Node.createInt(new RGeom()); + return RandFunction1Node.createInt(RGeom::new); case "rpois": - return RandFunction1Node.createInt(new RPois()); + return RandFunction1Node.createInt(RPois::new); case "rnbinom": - return RandFunction2Node.createInt(new RNBinomFunc()); + return RandFunction2Node.createInt(RNBinomFunc::new); case "rt": - return RandFunction1Node.createDouble(new Rt()); + return RandFunction1Node.createDouble(Rt::new); case "rsignrank": - return RandFunction1Node.createInt(new RSignrank()); + return RandFunction1Node.createInt(RSignrank::new); case "rhyper": - return RandFunction3Node.createInt(new RHyper()); + return RandFunction3Node.createInt(RHyper::new); case "phyper": return StatsFunctionsNodes.Function4_2Node.create(new PHyper()); case "dhyper": @@ -400,7 +400,7 @@ public class CallAndExternalFunctions { case "dweibull": return StatsFunctionsNodes.Function3_1Node.create(new DWeibull()); case "rbinom": - return RandFunction2Node.createInt(new Rbinom()); + return RandFunction2Node.createInt(Rbinom::new); case "pbinom": return StatsFunctionsNodes.Function3_2Node.create(new Pbinom()); case "pbeta": @@ -458,7 +458,7 @@ public class CallAndExternalFunctions { case "dt": return StatsFunctionsNodes.Function2_1Node.create(new Dt()); case "rlnorm": - return RandFunction2Node.createDouble(new LogNormal.RLNorm()); + return RandFunction2Node.createDouble(LogNormal.RLNorm::new); case "dlnorm": return StatsFunctionsNodes.Function3_1Node.create(new DLNorm()); case "qlnorm": diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/nmath/RandomFunctions.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/nmath/RandomFunctions.java index 736cf4b3bd..3c4f0e2224 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/nmath/RandomFunctions.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/nmath/RandomFunctions.java @@ -47,11 +47,11 @@ public class RandomFunctions { } } - public abstract static class RandFunction1_Double extends RandFunction2_Double { + public abstract static class RandFunction1_Double extends RandFunction3_Double { public abstract double execute(double a, RandomNumberProvider rand); @Override - public final double execute(double a, double b, RandomNumberProvider rand) { + public final double execute(double a, double b, double c, RandomNumberProvider rand) { return execute(a, rand); } } diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/nmath/distr/RMultinom.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/nmath/distr/RMultinom.java index 1acc54576b..822b14b94a 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/nmath/distr/RMultinom.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/nmath/distr/RMultinom.java @@ -18,7 +18,8 @@ import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.RRuntime; -import com.oracle.truffle.r.runtime.data.nodes.ReadAccessor; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; import com.oracle.truffle.r.runtime.nmath.RandomFunctions.RandomNumberProvider; public final class RMultinom { @@ -32,7 +33,7 @@ public final class RMultinom { * prob[j]) , sum_j rN[j] == n, sum_j prob[j] == 1. */ @TruffleBoundary - public static boolean rmultinom(int nIn, ReadAccessor.Double prob, int maxK, int[] rN, int rnStartIdx, RandomNumberProvider rand, Rbinom rbinom) { + public static boolean rmultinom(int nIn, SequentialIterator probsIter, VectorAccess probsAccess, double sum, int[] rN, int rnStartIdx, RandomNumberProvider rand, Rbinom rbinom) { /* * This calculation is sensitive to exact values, so we try to ensure that the calculations * are as accurate as possible so different platforms are more likely to give the same @@ -40,6 +41,7 @@ public final class RMultinom { */ int n = nIn; + int maxK = probsAccess.getLength(probsIter); if (RRuntime.isNA(maxK) || maxK < 1 || RRuntime.isNA(n) || n < 0) { if (rN.length > rnStartIdx) { rN[rnStartIdx] = RRuntime.INT_NA; @@ -52,8 +54,9 @@ public final class RMultinom { * shorter and drop that check ! */ /* LDOUBLE */double pTot = 0.; - for (int k = 0; k < maxK; k++) { - double pp = prob.getDataAt(k); + probsAccess.reset(probsIter); + for (int k = 0; probsAccess.next(probsIter); k++) { + double pp = probsAccess.getDouble(probsIter) / sum; if (!Double.isFinite(pp) || pp < 0. || pp > 1.) { rN[rnStartIdx + k] = RRuntime.INT_NA; return false; @@ -74,9 +77,10 @@ public final class RMultinom { } /* Generate the first K-1 obs. via binomials */ - for (int k = 0; k < maxK - 1; k++) { + probsAccess.reset(probsIter); + for (int k = 0; probsAccess.next(probsIter) && k < maxK - 1; k++) { /* (p_tot, n) are for "remaining binomial" */ - /* LDOUBLE */double probK = prob.getDataAt(k); + /* LDOUBLE */double probK = probsAccess.getDouble(probsIter) / sum; if (probK != 0.) { double pp = probK / pTot; // System.out.printf("[%d] %.17f\n", k + 1, pp); -- GitLab