diff --git a/com.oracle.truffle.r.native/fficall/src/jni/userrng.c b/com.oracle.truffle.r.native/fficall/src/jni/userrng.c new file mode 100644 index 0000000000000000000000000000000000000000..35d4e2bd9128da1be5cba8a53251026f43945d8d --- /dev/null +++ b/com.oracle.truffle.r.native/fficall/src/jni/userrng.c @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2016, 2016, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +#include <rffiutils.h> + +typedef void (*call_init)(int seed); +typedef double* (*call_rand)(void); +typedef int* (*call_nSeed)(void); +typedef int* (*call_seeds)(void); + +JNIEXPORT void JNICALL +Java_com_oracle_truffle_r_runtime_ffi_jnr_JNI_1UserRng_init(JNIEnv *env, jclass c, jlong address, jint seed) { + call_init f = (call_init) address; + f(seed); +} + +JNIEXPORT double JNICALL +Java_com_oracle_truffle_r_runtime_ffi_jnr_JNI_1UserRng_rand(JNIEnv *env, jclass c, jlong address) { + call_rand f = (call_rand) address; + double* dp = f(); + return *dp; +} + +JNIEXPORT jint JNICALL +Java_com_oracle_truffle_r_runtime_ffi_jnr_JNI_1UserRng_nSeed(JNIEnv *env, jclass c, jlong address) { + call_nSeed f = (call_nSeed) address; + int *pn = f(); + return *pn; +} + +JNIEXPORT void JNICALL +Java_com_oracle_truffle_r_runtime_ffi_jnr_JNI_1UserRng_seeds(JNIEnv *env, jclass c, jlong address, jintArray seedsArray) { + call_seeds f = (call_seeds) address; + int *pseeds = f(); + int seedslen = (*env)->GetArrayLength(env, seedsArray); + int *data = (*env)->GetIntArrayElements(env, seedsArray, NULL); + for (int i = 0; i < seedslen; i++) { + data[i] = pseeds[i]; + } + (*env)->ReleaseIntArrayElements(env, seedsArray, data, 0); +} diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/IntersectFastPath.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/IntersectFastPath.java index e0f1161c1d5415edda8156fef193ab9269f7bafa..611cf68d898d0ed50f0630d15a1059eee1cf0174 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/IntersectFastPath.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/IntersectFastPath.java @@ -36,35 +36,43 @@ import com.oracle.truffle.r.runtime.nodes.RNode; public abstract class IntersectFastPath extends RFastPathNode { + protected static final int TYPE_LIMIT = 2; + private static final int[] EMPTY_INT_ARRAY = new int[0]; - @Specialization(guards = {"x.getLength() > 0", "y.getLength() > 0"}) - protected RAbstractIntVector intersect(RAbstractIntVector x, RAbstractIntVector y, // - @Cached("createBinaryProfile()") ConditionProfile isXSortedProfile, // - @Cached("createBinaryProfile()") ConditionProfile isYSortedProfile, // + @Specialization(limit = "TYPE_LIMIT", guards = {"x.getLength() > 0", "y.getLength() > 0", "x.getClass() == xClass", "y.getClass() == yClass"}) + protected RAbstractIntVector intersect(RAbstractIntVector x, RAbstractIntVector y, + @Cached("x.getClass()") Class<? extends RAbstractIntVector> xClass, + @Cached("y.getClass()") Class<? extends RAbstractIntVector> yClass, + @Cached("createBinaryProfile()") ConditionProfile isXSortedProfile, + @Cached("createBinaryProfile()") ConditionProfile isYSortedProfile, @Cached("createBinaryProfile()") ConditionProfile resultLengthMatchProfile) { - int xLength = x.getLength(); - int yLength = y.getLength(); + // apply the type profiles: + RAbstractIntVector profiledX = xClass.cast(x); + RAbstractIntVector profiledY = yClass.cast(y); + + int xLength = profiledX.getLength(); + int yLength = profiledY.getLength(); RNode.reportWork(this, xLength + yLength); int count = 0; int[] result = EMPTY_INT_ARRAY; int maxResultLength = Math.min(xLength, yLength); - if (isXSortedProfile.profile(isSorted(x))) { + if (isXSortedProfile.profile(isSorted(profiledX))) { RAbstractIntVector tempY; - if (!isYSortedProfile.profile(isSorted(y))) { + if (isYSortedProfile.profile(isSorted(profiledY))) { + tempY = profiledY; + } else { int[] temp = new int[yLength]; for (int i = 0; i < yLength; i++) { - temp[i] = y.getDataAt(i); + temp[i] = profiledY.getDataAt(i); } sort(temp); - tempY = RDataFactory.createIntVector(temp, y.isComplete()); - } else { - tempY = y; + tempY = RDataFactory.createIntVector(temp, profiledY.isComplete()); } int xPos = 0; int yPos = 0; - int xValue = x.getDataAt(xPos); + int xValue = profiledX.getDataAt(xPos); int yValue = tempY.getDataAt(yPos); while (true) { if (xValue == yValue) { @@ -77,7 +85,7 @@ public abstract class IntersectFastPath extends RFastPathNode { if (xPos >= xLength - 1) { break; } - int nextValue = x.getDataAt(xPos + 1); + int nextValue = profiledX.getDataAt(xPos + 1); if (xValue != nextValue) { break; } @@ -87,13 +95,13 @@ public abstract class IntersectFastPath extends RFastPathNode { if (++xPos >= xLength || ++yPos >= yLength) { break; } - xValue = x.getDataAt(xPos); + xValue = profiledX.getDataAt(xPos); yValue = tempY.getDataAt(yPos); } else if (xValue < yValue) { if (++xPos >= xLength) { break; } - xValue = x.getDataAt(xPos); + xValue = profiledX.getDataAt(xPos); } else { if (++yPos >= yLength) { break; @@ -105,12 +113,12 @@ public abstract class IntersectFastPath extends RFastPathNode { int[] temp = new int[yLength]; boolean[] used = new boolean[yLength]; for (int i = 0; i < yLength; i++) { - temp[i] = y.getDataAt(i); + temp[i] = profiledY.getDataAt(i); } sort(temp); for (int i = 0; i < xLength; i++) { - int value = x.getDataAt(i); + int value = profiledX.getDataAt(i); int pos = Arrays.binarySearch(temp, value); if (pos >= 0 && !used[pos]) { used[pos] = true; @@ -121,7 +129,7 @@ public abstract class IntersectFastPath extends RFastPathNode { } } } - return RDataFactory.createIntVector(resultLengthMatchProfile.profile(count == result.length) ? result : Arrays.copyOf(result, count), x.isComplete() | y.isComplete()); + return RDataFactory.createIntVector(resultLengthMatchProfile.profile(count == result.length) ? result : Arrays.copyOf(result, count), profiledX.isComplete() | profiledY.isComplete()); } private static boolean isSorted(RAbstractIntVector vector) { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/ArgumentStatePush.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/ArgumentStatePush.java index a45022312ca4653b6c161b8c0734e45354ef3108..f600dd260b59b01622b64f07457435d412e2974c 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/ArgumentStatePush.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/ArgumentStatePush.java @@ -65,49 +65,51 @@ public abstract class ArgumentStatePush extends Node { public void transitionState(VirtualFrame frame, RShareable shareable) { if (isRefCountUpdateable.profile(!shareable.isSharedPermanent())) { shareable.incRefCount(); - } - if (!FastROptions.RefCountIncrementOnly.getBooleanValue()) { - if (mask == 0) { - CompilerDirectives.transferToInterpreterAndInvalidate(); - if (shareable instanceof RAbstractContainer) { - if (shareable instanceof RLanguage || ((RAbstractContainer) shareable).getLength() < REF_COUNT_SIZE_THRESHOLD) { - // don't decrement ref count for small objects or language objects- this is - // pretty conservative and can be further finessed + if (!FastROptions.RefCountIncrementOnly.getBooleanValue()) { + if (mask == 0) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + if (shareable instanceof RAbstractContainer) { + if (shareable instanceof RLanguage || ((RAbstractContainer) shareable).getLength() < REF_COUNT_SIZE_THRESHOLD) { + // don't decrement ref count for small objects or language objects- this + // is + // pretty conservative and can be further finessed + mask = -1; + return; + } + } + RFunction fun = RArguments.getFunction(frame); + if (fun == null) { mask = -1; return; } + Object root = fun.getRootNode(); + if (!(root instanceof FunctionDefinitionNode)) { + // root is RBuiltinRootNode + mask = -1; + return; + } + FunctionDefinitionNode fdn = (FunctionDefinitionNode) root; + PostProcessArgumentsNode postProcessNode = fdn.getArgPostProcess(); + if (postProcessNode == null) { + // arguments to this function are not to be reference counted + mask = -1; + return; + } + // this is needed for when FunctionDefinitionNode is split by the Truffle + // runtime + postProcessNode = postProcessNode.getActualNode(); + if (index >= Math.min(postProcessNode.getLength(), MAX_COUNTED_ARGS)) { + mask = -1; + return; + } + mask = 1 << index; + int transArgsBitSet = postProcessNode.transArgsBitSet; + postProcessNode.transArgsBitSet = transArgsBitSet | mask; + writeArgNode = insert(WriteLocalFrameVariableNode.createForRefCount(Integer.valueOf(mask))); } - RFunction fun = RArguments.getFunction(frame); - if (fun == null) { - mask = -1; - return; - } - Object root = fun.getRootNode(); - if (!(root instanceof FunctionDefinitionNode)) { - // root is RBuiltinRootNode - mask = -1; - return; - } - FunctionDefinitionNode fdn = (FunctionDefinitionNode) root; - PostProcessArgumentsNode postProcessNode = fdn.getArgPostProcess(); - if (postProcessNode == null) { - // arguments to this function are not to be reference counted - mask = -1; - return; - } - // this is needed for when FunctionDefinitionNode is split by the Truffle runtime - postProcessNode = postProcessNode.getActualNode(); - if (index >= Math.min(postProcessNode.getLength(), MAX_COUNTED_ARGS)) { - mask = -1; - return; + if (mask != -1) { + writeArgNode.execute(frame, shareable); } - mask = 1 << index; - int transArgsBitSet = postProcessNode.transArgsBitSet; - postProcessNode.transArgsBitSet = transArgsBitSet | mask; - writeArgNode = insert(WriteLocalFrameVariableNode.createForRefCount(Integer.valueOf(mask))); - } - if (mask != -1) { - writeArgNode.execute(frame, shareable); } } } @@ -135,8 +137,7 @@ public abstract class ArgumentStatePush extends Node { // this is expected to be used in rare cases where no RNode is easily available if (o instanceof RShareable) { RShareable shareable = (RShareable) o; - // it's never decremented so no point in incrementing past shared state - if (!shareable.isShared()) { + if (!shareable.isSharedPermanent()) { shareable.incRefCount(); } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallNode.java index b26a5aca8397d77862132affd87663c5c274bbe6..52c81b0097da24334ee814cd0ebeadc4c0875a90 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/RCallNode.java @@ -34,7 +34,6 @@ import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.NodeChild; import com.oracle.truffle.api.dsl.Specialization; -import com.oracle.truffle.api.frame.FrameSlot; import com.oracle.truffle.api.frame.FrameSlotTypeException; import com.oracle.truffle.api.frame.MaterializedFrame; import com.oracle.truffle.api.frame.VirtualFrame; @@ -45,7 +44,6 @@ import com.oracle.truffle.api.nodes.ExplodeLoop; import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.nodes.NodeCost; import com.oracle.truffle.api.nodes.NodeInfo; -import com.oracle.truffle.api.nodes.NodeUtil; import com.oracle.truffle.api.nodes.RootNode; import com.oracle.truffle.api.nodes.UnexpectedResultException; import com.oracle.truffle.api.profiles.BranchProfile; @@ -312,7 +310,7 @@ public abstract class RCallNode extends RCallBaseNode implements RSyntaxNode, RS } protected RNode createDispatchArgument(int index) { - return new ForcePromiseNode(NodeUtil.cloneNode(arguments[index].asRNode())); + return new ForcePromiseNode(RASTUtils.cloneNode(arguments[index].asRNode())); } /** @@ -339,7 +337,7 @@ public abstract class RCallNode extends RCallBaseNode implements RSyntaxNode, RS @Cached("createBinaryProfile()") ConditionProfile resultIsBuiltinProfile) { RBuiltinDescriptor builtin = builtinProfile.profile(function.getRBuiltin()); Object dispatchObject = dispatchArgument.execute(frame); - FrameSlot slot = dispatchTempSlot.initialize(frame, dispatchObject, () -> internalDispatchCall = null); + dispatchTempSlot.initialize(frame, dispatchObject, () -> internalDispatchCall = null); try { RStringVector type = classHierarchyNode.execute(dispatchObject); S3Args s3Args; @@ -362,7 +360,7 @@ public abstract class RCallNode extends RCallBaseNode implements RSyntaxNode, RS } return internalDispatchCall.execute(frame, resultFunction, lookupVarArgs(frame), s3Args, null); } finally { - dispatchTempSlot.cleanup(frame, slot); + dispatchTempSlot.cleanup(frame, dispatchObject); } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/TemporarySlotNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/TemporarySlotNode.java index 56c0751202f6e38fef94f115bb4dcf895ce01efc..cab3f3c4da90b28fb840338a03006ebcd7bb9b31 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/TemporarySlotNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/function/TemporarySlotNode.java @@ -23,50 +23,53 @@ package com.oracle.truffle.r.nodes.function; import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; import com.oracle.truffle.api.frame.FrameSlot; +import com.oracle.truffle.api.frame.FrameSlotKind; import com.oracle.truffle.api.frame.FrameSlotTypeException; import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.nodes.Node; -import com.oracle.truffle.r.nodes.access.FrameSlotNode; import com.oracle.truffle.r.runtime.RInternalError; public final class TemporarySlotNode extends Node { private static final Object[] defaultTempIdentifiers = new Object[]{new Object(), new Object(), new Object(), new Object(), new Object(), new Object(), new Object(), new Object()}; - @Child private FrameSlotNode tempSlot; + @CompilationFinal private FrameSlot tempSlot; private int tempIdentifier; private Object identifier; - public FrameSlot initialize(VirtualFrame frame, Object value, Runnable invalidate) { + public void initialize(VirtualFrame frame, Object value, Runnable invalidate) { if (tempSlot == null) { CompilerDirectives.transferToInterpreterAndInvalidate(); - tempSlot = insert(FrameSlotNode.createInitialized(frame.getFrameDescriptor(), identifier = defaultTempIdentifiers[0], true)); + tempSlot = frame.getFrameDescriptor().findOrAddFrameSlot(identifier = defaultTempIdentifiers[0], FrameSlotKind.Object); invalidate.run(); } - FrameSlot slot = tempSlot.executeFrameSlot(frame); try { - if (frame.isObject(slot) && frame.getObject(slot) != null) { + if (frame.getObject(tempSlot) != null) { CompilerDirectives.transferToInterpreterAndInvalidate(); // keep the complete loop in the slow path do { tempIdentifier++; identifier = tempIdentifier < defaultTempIdentifiers.length ? defaultTempIdentifiers[tempIdentifier] : new Object(); - tempSlot.replace(FrameSlotNode.createInitialized(frame.getFrameDescriptor(), identifier, true)); + tempSlot = frame.getFrameDescriptor().findOrAddFrameSlot(identifier, FrameSlotKind.Object); invalidate.run(); - slot = tempSlot.executeFrameSlot(frame); - } while (frame.isObject(slot) && frame.getObject(slot) != null); + } while (frame.getObject(tempSlot) != null); } } catch (FrameSlotTypeException e) { + CompilerDirectives.transferToInterpreter(); throw RInternalError.shouldNotReachHere(); } - frame.setObject(slot, value); - return slot; + frame.setObject(tempSlot, value); } - @SuppressWarnings("static-method") - public void cleanup(VirtualFrame frame, FrameSlot slot) { - frame.setObject(slot, null); + public void cleanup(VirtualFrame frame, Object object) { + try { + assert frame.getObject(tempSlot) == object; + } catch (FrameSlotTypeException e) { + throw RInternalError.shouldNotReachHere(); + } + frame.setObject(tempSlot, null); } public Object getIdentifier() { diff --git a/com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNR_UserRng.java b/com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNI_UserRng.java similarity index 50% rename from com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNR_UserRng.java rename to com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNI_UserRng.java index ff50fa1416cd3935c2afbe2f0b9915aeb257a3db..fcb6c4c8d6f3ed744a912944823950b842892264 100644 --- a/com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNR_UserRng.java +++ b/com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNI_UserRng.java @@ -22,82 +22,41 @@ */ package com.oracle.truffle.r.runtime.ffi.jnr; +import static com.oracle.truffle.r.runtime.rng.user.UserRNG.Function; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.r.runtime.ffi.UserRngRFFI; -import jnr.ffi.LibraryLoader; -import jnr.ffi.Pointer; -import jnr.ffi.annotations.In; - -//Checkstyle: stop method name -public class JNR_UserRng implements UserRngRFFI { - public interface UserRng { - void user_unif_init(@In int seed); - - Pointer user_unif_rand(); - - Pointer user_unif_nseed(); - - Pointer user_unif_seedloc(); - } - - private static class UserRngProvider { - private static String libPath; - private static UserRng userRng; - - UserRngProvider(String libPath) { - UserRngProvider.libPath = libPath; - } - - @TruffleBoundary - private static UserRng createAndLoadLib() { - return LibraryLoader.create(UserRng.class).load(libPath); - } - - static UserRng userRng() { - if (userRng == null) { - userRng = createAndLoadLib(); - } - return userRng; - } - } - - private static UserRng userRng() { - return UserRngProvider.userRng(); - } - - @Override - @SuppressWarnings("unused") - public void setLibrary(String path) { - new UserRngProvider(path); - - } - +public class JNI_UserRng implements UserRngRFFI { @Override @TruffleBoundary public void init(int seed) { - userRng().user_unif_init(seed); + init(Function.Init.getAddress(), seed); + } @Override @TruffleBoundary public double rand() { - Pointer pDouble = userRng().user_unif_rand(); - return pDouble.getDouble(0); + return rand(Function.Rand.getAddress()); } @Override @TruffleBoundary public int nSeed() { - return userRng().user_unif_nseed().getInt(0); + return nSeed(Function.NSeed.getAddress()); } @Override @TruffleBoundary public void seeds(int[] n) { - Pointer pInt = userRng().user_unif_seedloc(); - for (int i = 0; i < n.length; i++) { - n[i] = pInt.getInt(i * 4); - } + seeds(Function.Seedloc.getAddress(), n); } + + private static native void init(long address, int seed); + + private static native double rand(long address); + + private static native int nSeed(long address); + + private static native void seeds(long address, int[] n); } diff --git a/com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNR_RFFIFactory.java b/com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNR_RFFIFactory.java index 05a093a010bd91d13ab53b115f77be56f86458b5..19aba1169678bd8cdfffd5d9d8e2c4c5ed5aec24 100644 --- a/com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNR_RFFIFactory.java +++ b/com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNR_RFFIFactory.java @@ -164,7 +164,7 @@ public class JNR_RFFIFactory extends RFFIFactory implements RFFI { public UserRngRFFI getUserRngRFFI() { if (userRngRFFI == null) { CompilerDirectives.transferToInterpreterAndInvalidate(); - userRngRFFI = new JNR_UserRng(); + userRngRFFI = new JNI_UserRng(); } return userRngRFFI; } diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/UserRngRFFI.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/UserRngRFFI.java index b3b7dc0dc3eb98c14bad2cc40154a93d2c95b23f..dc692727fcfe6b46100e854a859d63ec0abaa700 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/UserRngRFFI.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/UserRngRFFI.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2014, 2014, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2014, 2016, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -23,13 +23,10 @@ package com.oracle.truffle.r.runtime.ffi; /** - * Explicit statically typed interface to user-supplied random number generators. TODO This could - * eventually be subsumed by {@link CRFFI}. + * Explicit statically typed interface to user-supplied random number generators. */ public interface UserRngRFFI { - void setLibrary(String path); - void init(int seed); double rand(); diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/rng/user/UserRNG.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/rng/user/UserRNG.java index bc649e2415f24676d88609143feee3c1f02631ab..9e1f745d769433b28be59b45398e497348bf8cce 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/rng/user/UserRNG.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/rng/user/UserRNG.java @@ -35,42 +35,68 @@ import com.oracle.truffle.r.runtime.rng.RRNG.Kind; * Interface to a user-supplied RNG. */ public final class UserRNG extends RNGInitAdapter { - - private static final String USER_UNIF_RAND = "user_unif_rand"; - private static final String USER_UNIF_INIT = "user_unif_init"; private static final boolean OPTIONAL = true; - @SuppressWarnings("unused") private long userUnifRand; - @SuppressWarnings("unused") private long userUnifInit; - private long userUnifNSeed; - private long userUnifSeedloc; + public enum Function { + Rand(!OPTIONAL), + Init(OPTIONAL), + NSeed(OPTIONAL), + Seedloc(OPTIONAL); + + private long address; + private final String symbol; + private final boolean optional; + + Function(boolean optional) { + this.symbol = "user_unif_" + name().toLowerCase(); + this.optional = optional; + } + + private boolean isDefined() { + return address != 0; + } + + public long getAddress() { + return address; + } + + private void setAddress(DLLInfo dllInfo) { + this.address = findSymbol(symbol, dllInfo, optional); + } + + } + private UserRngRFFI userRngRFFI; private int nSeeds = 0; @Override @TruffleBoundary public void init(int seed) { - DLLInfo dllInfo = DLL.findLibraryContainingSymbol(USER_UNIF_RAND); + DLLInfo dllInfo = DLL.findLibraryContainingSymbol(Function.Rand.symbol); if (dllInfo == null) { - throw RError.error(RError.NO_CALLER, RError.Message.RNG_SYMBOL, USER_UNIF_RAND); + throw RError.error(RError.NO_CALLER, RError.Message.RNG_SYMBOL, Function.Rand.symbol); + } + for (Function f : Function.values()) { + f.setAddress(dllInfo); } - userUnifRand = findSymbol(USER_UNIF_RAND, dllInfo, !OPTIONAL); - userUnifInit = findSymbol(USER_UNIF_INIT, dllInfo, OPTIONAL); - userUnifNSeed = findSymbol(USER_UNIF_INIT, dllInfo, OPTIONAL); - userUnifSeedloc = findSymbol(USER_UNIF_INIT, dllInfo, OPTIONAL); userRngRFFI = RFFIFactory.getRFFI().getUserRngRFFI(); - userRngRFFI.setLibrary(dllInfo.path); - userRngRFFI.init(seed); - if (userUnifSeedloc != 0 && userUnifNSeed == 0) { + if (Function.Init.isDefined()) { + userRngRFFI.init(seed); + } + if (Function.Seedloc.isDefined() && !Function.NSeed.isDefined()) { RError.warning(RError.NO_CALLER, RError.Message.RNG_READ_SEEDS); } - int ns = userRngRFFI.nSeed(); - if (ns < 0 || ns > 625) { - RError.warning(RError.NO_CALLER, RError.Message.GENERIC, "seed length must be in 0...625; ignored"); - } else { - nSeeds = ns; - // TODO: if we ever (initially) share iSeed (as GNU R does) we may need to assign this - // generator's iSeed here + if (Function.NSeed.isDefined()) { + int ns = userRngRFFI.nSeed(); + if (ns < 0 || ns > 625) { + RError.warning(RError.NO_CALLER, RError.Message.GENERIC, "seed length must be in 0...625; ignored"); + } else { + nSeeds = ns; + /* + * TODO: if we ever (initially) share iSeed (as GNU R does) we may need to assign + * this generator's iSeed here + */ + } } } @@ -95,7 +121,7 @@ public final class UserRNG extends RNGInitAdapter { @Override public int[] getSeeds() { - if (userUnifSeedloc == 0) { + if (!Function.Seedloc.isDefined()) { return null; } int[] result = new int[nSeeds];