diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SerializeFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SerializeFunctions.java index 8d1d96657307c3c77269954db15558cf88c23a05..a560f8f4a05a57ff7a664ac4c59e5f3c68d12f05 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SerializeFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SerializeFunctions.java @@ -23,6 +23,7 @@ package com.oracle.truffle.r.nodes.builtin.base; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.asIntegerVector; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.eq; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.findFirst; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.integerValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue; @@ -32,6 +33,7 @@ import static com.oracle.truffle.r.runtime.builtins.RBehavior.IO; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; import java.io.IOException; +import java.util.function.Function; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Specialization; @@ -69,8 +71,7 @@ public class SerializeFunctions { } @TruffleBoundary - protected static Object doSerializeToConnBase(RBaseNode node, Object object, int connIndex, int type, @SuppressWarnings("unused") byte xdrLogical, @SuppressWarnings("unused") RNull version, - @SuppressWarnings("unused") RNull refhook) { + protected static Object doSerializeToConnBase(RBaseNode node, Object object, int connIndex, int type) { // xdr is only relevant if ascii is false try (RConnection openConn = RConnection.fromIndex(connIndex).forceOpen(type != RSerialize.XDR ? "wt" : "wb")) { if (!openConn.canWrite()) { @@ -90,6 +91,12 @@ public class SerializeFunctions { casts.arg("con").mustBe(integerValue()).asIntegerVector().findFirst(); } + private static void version(Casts casts) { + // This just validates the value. It must be either default NULL or 2. Specializations + // should use 'Object' + casts.arg("version").allowNull().asIntegerVector().findFirst().mustBe(eq(2), Message.VERSION_N_NOT_SUPPORTED, (Function<Object, Object>) n -> n); + } + @RBuiltin(name = "unserializeFromConn", kind = INTERNAL, parameterNames = {"con", "refhook"}, behavior = IO) public abstract static class UnserializeFromConn extends RBuiltinNode.Arg2 { @@ -118,12 +125,12 @@ public class SerializeFunctions { casts.arg("object").mustNotBeMissing(); connection(casts); casts.arg("ascii").mustBe(logicalValue(), RError.Message.ASCII_NOT_LOGICAL).asLogicalVector().findFirst(); - casts.arg("version").allowNull().mustBe(integerValue()).asIntegerVector().findFirst(); + version(casts); casts.arg("refhook").mustNotBeMissing(); } @Specialization - protected Object doSerializeToConn(Object object, int conn, byte asciiLogical, RNull version, RNull refhook) { + protected Object doSerializeToConn(Object object, int conn, byte asciiLogical, @SuppressWarnings("unused") Object version, @SuppressWarnings("unused") RNull refhook) { int type; if (asciiLogical == RRuntime.LOGICAL_NA) { type = RSerialize.ASCII_HEX; @@ -132,14 +139,7 @@ public class SerializeFunctions { } else { type = RSerialize.XDR; } - return doSerializeToConnBase(this, object, conn, type, RRuntime.LOGICAL_NA, version, refhook); - } - - @SuppressWarnings("unused") - @Specialization - protected Object doSerializeToConn(Object object, int conn, byte asciiLogical, int version, Object refhook) { - // TODO: implement "version" support - throw RError.error(this, RError.Message.UNIMPLEMENTED_ARG_TYPE, 4); + return doSerializeToConnBase(this, object, conn, type); } } @@ -170,16 +170,16 @@ public class SerializeFunctions { Casts casts = new Casts(Serialize.class); casts.arg("con").allowNull().mustBe(integerValue()).asIntegerVector().findFirst(); casts.arg("type").asIntegerVector().findFirst(); + version(casts); } @Specialization - protected Object serialize(Object object, int conn, int type, RNull version, RNull refhook) { - return doSerializeToConnBase(this, object, conn, type, RRuntime.LOGICAL_NA, version, refhook); + protected Object serialize(Object object, int conn, int type, @SuppressWarnings("unused") Object version, @SuppressWarnings("unused") RNull refhook) { + return doSerializeToConnBase(this, object, conn, type); } - @SuppressWarnings("unused") @Specialization - protected Object serialize(Object object, RNull conn, int type, RNull version, RNull refhook) { + protected Object serialize(Object object, RNull conn, int type, @SuppressWarnings("unused") Object version, @SuppressWarnings("unused") RNull refhook) { byte[] data = RSerialize.serialize(object, type, RSerialize.DEFAULT_VERSION, null); return RDataFactory.createRawVector(data); } @@ -192,14 +192,15 @@ public class SerializeFunctions { Casts casts = new Casts(SerializeB.class); connection(casts); casts.arg("xdr").asLogicalVector().findFirst(); + version(casts); } @Specialization - protected Object serializeB(Object object, int conn, byte xdrLogical, RNull version, RNull refhook) { + protected Object serializeB(Object object, int conn, byte xdrLogical, @SuppressWarnings("unused") Object version, @SuppressWarnings("unused") RNull refhook) { if (!RRuntime.fromLogical(xdrLogical)) { throw RError.nyi(this, "xdr==FALSE"); } - return doSerializeToConnBase(this, object, conn, RRuntime.LOGICAL_FALSE, xdrLogical, version, refhook); + return doSerializeToConnBase(this, object, conn, RRuntime.LOGICAL_FALSE); } } } diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java index 166064950f4095a202fc48b62bf2087af78098c9..13c2aa50eb7751e02745839b0a8d9f707ecc91cf 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java @@ -915,6 +915,7 @@ public final class RError extends RuntimeException implements TruffleException { BAD_CONSTANT_COUNT("bad constant count"), MUST_BE_MULTIPLE("argument '%s' must be a multiple of %d long"), MUSTNOT_CONTAIN_NAS("argument '%s' must not contain NAs"), + VERSION_N_NOT_SUPPORTED("version %d not supported"), ATOMIC_VECTOR_ARGUMENTS_ONLY("atomic vector arguments only"); public final String message; diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index 1147bbb540d145862826d96d57cee320563fb62a..19c4c9d1609ce831a8b4ac1501ca9b8b823a8e23 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -67100,6 +67100,15 @@ integer(0) [26] fd 00 00 04 02 00 00 00 01 00 04 00 09 00 00 00 01 66 00 00 00 0a 00 00 00 [51] 01 80 00 00 00 00 00 00 fe 00 00 00 fe 00 00 00 fe +##com.oracle.truffle.r.test.builtins.TestBuiltin_serialize.testserialize# +#serialize('foo', NULL, version=2) + [1] 58 0a 00 00 00 02 00 03 04 00 00 02 03 00 00 00 00 10 00 00 00 01 00 04 00 +[26] 09 00 00 00 03 66 6f 6f + +##com.oracle.truffle.r.test.builtins.TestBuiltin_serialize.testserialize# +#serialize('foo', NULL, version=3) +Error in serialize("foo", NULL, version = 3) : version 3 not supported + ##com.oracle.truffle.r.test.builtins.TestBuiltin_serialize.testserialize# #setClass('foo', slots = c(x='numeric', y='numeric')); t1 <- new('foo', x=4, y=c(77,88)); options(keep.source=FALSE); serialize(t1, connection=NULL) [1] 58 0a 00 00 00 02 00 03 04 00 00 02 03 00 00 01 03 19 00 00 04 02 00 00 00 diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_serialize.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_serialize.java index e0ef86b8142d201c59940a44be130243039b8688..07b12af333d27fd0cff34d59e59c3c151129a8fd 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_serialize.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_serialize.java @@ -110,6 +110,8 @@ public class TestBuiltin_serialize extends TestBase { assertEval(Ignored.ImplementationError, "options(keep.source=FALSE); fc <- setClass('FooSerial1', representation(a = 'call')); serialize(fc, connection=NULL)"); assertEval("{ options(keep.source=FALSE); f <- function() NULL; attributes(f) <- list(skeleton=quote(`<undef>`())); data <- serialize(f, conn=NULL); unserialize(conn=data) }"); + assertEval("serialize('foo', NULL, version=2)"); + assertEval("serialize('foo', NULL, version=3)"); } @Test