diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java index 2a53a4239b9e3b65797b0698695c2f47b69a6a5a..e1c3ca08e3b0899498abbe29882a93dc486d7462 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java @@ -629,6 +629,7 @@ public class BasePackage extends RBuiltinPackage { add(Rank.class, RankNodeGen::create); add(RNGFunctions.RNGkind.class, RNGFunctionsFactory.RNGkindNodeGen::create); add(RNGFunctions.SetSeed.class, RNGFunctionsFactory.SetSeedNodeGen::create); + add(RNGFunctions.FastRSetSeed.class, RNGFunctionsFactory.FastRSetSeedNodeGen::create); add(RVersion.class, RVersionNodeGen::create); add(RawFunctions.CharToRaw.class, RawFunctionsFactory.CharToRawNodeGen::create); add(RawFunctions.RawToChar.class, RawFunctionsFactory.RawToCharNodeGen::create); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/R/base_overrides.R b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/R/base_overrides.R index 9b10332922a25ad440fb1fae00a2f2f92fb17920..3c6bcb3f79c803d615c04f8ada00097ca71bfbc7 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/R/base_overrides.R +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/R/base_overrides.R @@ -27,3 +27,5 @@ eval(expression({ } }) }), asNamespace("base")) + +makeActiveBinding(".Random.seed", .fastr.set.seed, .GlobalEnv) diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RNGFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RNGFunctions.java index 146b768b00d6e5e06d36ed4f3d07fb6438f9ec89..6971216e1961bfc28dfbd1397dc3528e21372d2c 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RNGFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RNGFunctions.java @@ -33,15 +33,22 @@ import static com.oracle.truffle.r.runtime.RError.Message.UNIMPLEMENTED_TYPE_IN_ import static com.oracle.truffle.r.runtime.RVisibility.OFF; import static com.oracle.truffle.r.runtime.builtins.RBehavior.MODIFIES_STATE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; +import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; +import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.r.nodes.builtin.NodeWithArgumentCasts.Casts; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.builtins.RBuiltin; +import com.oracle.truffle.r.runtime.builtins.RBuiltinKind; +import com.oracle.truffle.r.runtime.context.RContext; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RIntVector; +import com.oracle.truffle.r.runtime.data.RMissing; import com.oracle.truffle.r.runtime.data.RNull; +import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.rng.RRNG; public class RNGFunctions { @@ -97,6 +104,43 @@ public class RNGFunctions { } } + @RBuiltin(name = ".fastr.set.seed", visibility = OFF, kind = PRIMITIVE, parameterNames = {"data"}, behavior = MODIFIES_STATE) + public abstract static class FastRSetSeed extends RBuiltinNode.Arg1 { + + static { + Casts.noCasts(FastRSetSeed.class); + } + + @Specialization + @TruffleBoundary + protected RNull setSeed(int[] data) { + RContext.getInstance().stateRNG.currentSeeds = data; + return RNull.instance; + } + + @Specialization + @TruffleBoundary + protected RNull setSeed(RAbstractIntVector data) { + int[] arr = new int[data.getLength()]; + for (int i = 0; i < arr.length; i++) { + arr[i] = data.getDataAt(i); + } + RContext.getInstance().stateRNG.currentSeeds = arr; + return RNull.instance; + } + + @Specialization + @TruffleBoundary + protected Object getSeed(@SuppressWarnings("unused") RMissing data) { + int[] seeds = RContext.getInstance().stateRNG.currentSeeds; + if (seeds != null) { + return RDataFactory.createIntVector(seeds, RDataFactory.INCOMPLETE_VECTOR); + } +// throw error(RError.Message.UNKNOWN_OBJECT, ".Random.seed"); + return RNull.instance; + } + } + private static final class CastsHelper { public static void kindInteger(Casts casts, String name, Message error, Object... messageArgs) { casts.arg(name).mapNull(constant(RRNG.NO_KIND_CHANGE)).mustBe(numericValue(), error, messageArgs).asIntegerVector().findFirst(); diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/rng/RRNG.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/rng/RRNG.java index eecc1c27b5f9536182fffd46fdddde09a55d3c1b..61aa34ed69756902342f3d6351b956ff91e14e22 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/rng/RRNG.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/rng/RRNG.java @@ -17,6 +17,7 @@ import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; +import com.oracle.truffle.r.runtime.RType; import com.oracle.truffle.r.runtime.context.RContext; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RIntVector; @@ -25,6 +26,7 @@ import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RPromise; import com.oracle.truffle.r.runtime.data.RTypedValue; import com.oracle.truffle.r.runtime.env.REnvironment; +import com.oracle.truffle.r.runtime.env.frame.ActiveBinding; import com.oracle.truffle.r.runtime.ffi.BaseRFFI; import com.oracle.truffle.r.runtime.rng.mm.MarsagliaMulticarry; import com.oracle.truffle.r.runtime.rng.mt.MersenneTwister; @@ -107,6 +109,7 @@ public class RRNG { private RandomNumberGenerator currentGenerator; private final RandomNumberGenerator[] allGenerators; private NormKind currentNormKind; + public int[] currentSeeds; private ContextStateImpl() { this.currentNormKind = DEFAULT_NORM_KIND; @@ -282,11 +285,7 @@ public class RRNG { @TruffleBoundary private static Object getDotRandomSeed() { - Object seed = REnvironment.globalEnv().get(RANDOM_SEED); - if (seed instanceof RPromise) { - seed = RContext.getRRuntimeASTAccess().forcePromise(RANDOM_SEED, seed); - } - return seed; + return RContext.getInstance().stateRNG.currentSeeds; } /** @@ -327,6 +326,9 @@ public class RRNG { } else if (seeds instanceof RIntVector) { RIntVector seedsVec = (RIntVector) seeds; tmp = seedsVec.getLength() == 0 ? RRuntime.INT_NA : seedsVec.getDataAt(0); + } else if (seeds instanceof int[]) { + int[] seedsArr = (int[]) seeds; + tmp = seedsArr.length == 0 ? RRuntime.INT_NA : seedsArr[0]; } else { assert seeds != RMissing.instance; assert seeds instanceof RTypedValue; @@ -410,7 +412,8 @@ public class RRNG { public static void putRNGState() { int[] seeds = currentGenerator().getSeeds(); seeds[0] = currentKind().ordinal() + 100 * currentNormKind().ordinal(); - RIntVector vector = RDataFactory.createIntVector(seeds, RDataFactory.INCOMPLETE_VECTOR); - REnvironment.globalEnv().safePut(RANDOM_SEED, vector.makeSharedPermanent()); + RContext.getInstance().stateRNG.currentSeeds = seeds; +// RIntVector vector = RDataFactory.createIntVector(seeds, RDataFactory.INCOMPLETE_VECTOR); +// REnvironment.globalEnv().safePut(RANDOM_SEED, vector.makeSharedPermanent()); } }