Skip to content
Snippets Groups Projects
Commit 78f44dc2 authored by stepan's avatar stepan
Browse files

Fix unsupported specialization in serialize

parent b016ac9d
No related branches found
No related tags found
No related merge requests found
......@@ -23,6 +23,7 @@
package com.oracle.truffle.r.nodes.builtin.base;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.asIntegerVector;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.eq;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.findFirst;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.integerValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue;
......@@ -32,6 +33,7 @@ import static com.oracle.truffle.r.runtime.builtins.RBehavior.IO;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
import java.io.IOException;
import java.util.function.Function;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Specialization;
......@@ -69,8 +71,7 @@ public class SerializeFunctions {
}
@TruffleBoundary
protected static Object doSerializeToConnBase(RBaseNode node, Object object, int connIndex, int type, @SuppressWarnings("unused") byte xdrLogical, @SuppressWarnings("unused") RNull version,
@SuppressWarnings("unused") RNull refhook) {
protected static Object doSerializeToConnBase(RBaseNode node, Object object, int connIndex, int type) {
// xdr is only relevant if ascii is false
try (RConnection openConn = RConnection.fromIndex(connIndex).forceOpen(type != RSerialize.XDR ? "wt" : "wb")) {
if (!openConn.canWrite()) {
......@@ -90,6 +91,12 @@ public class SerializeFunctions {
casts.arg("con").mustBe(integerValue()).asIntegerVector().findFirst();
}
private static void version(Casts casts) {
// This just validates the value. It must be either default NULL or 2. Specializations
// should use 'Object'
casts.arg("version").allowNull().asIntegerVector().findFirst().mustBe(eq(2), Message.VERSION_N_NOT_SUPPORTED, (Function<Object, Object>) n -> n);
}
@RBuiltin(name = "unserializeFromConn", kind = INTERNAL, parameterNames = {"con", "refhook"}, behavior = IO)
public abstract static class UnserializeFromConn extends RBuiltinNode.Arg2 {
......@@ -118,12 +125,12 @@ public class SerializeFunctions {
casts.arg("object").mustNotBeMissing();
connection(casts);
casts.arg("ascii").mustBe(logicalValue(), RError.Message.ASCII_NOT_LOGICAL).asLogicalVector().findFirst();
casts.arg("version").allowNull().mustBe(integerValue()).asIntegerVector().findFirst();
version(casts);
casts.arg("refhook").mustNotBeMissing();
}
@Specialization
protected Object doSerializeToConn(Object object, int conn, byte asciiLogical, RNull version, RNull refhook) {
protected Object doSerializeToConn(Object object, int conn, byte asciiLogical, @SuppressWarnings("unused") Object version, @SuppressWarnings("unused") RNull refhook) {
int type;
if (asciiLogical == RRuntime.LOGICAL_NA) {
type = RSerialize.ASCII_HEX;
......@@ -132,14 +139,7 @@ public class SerializeFunctions {
} else {
type = RSerialize.XDR;
}
return doSerializeToConnBase(this, object, conn, type, RRuntime.LOGICAL_NA, version, refhook);
}
@SuppressWarnings("unused")
@Specialization
protected Object doSerializeToConn(Object object, int conn, byte asciiLogical, int version, Object refhook) {
// TODO: implement "version" support
throw RError.error(this, RError.Message.UNIMPLEMENTED_ARG_TYPE, 4);
return doSerializeToConnBase(this, object, conn, type);
}
}
......@@ -170,16 +170,16 @@ public class SerializeFunctions {
Casts casts = new Casts(Serialize.class);
casts.arg("con").allowNull().mustBe(integerValue()).asIntegerVector().findFirst();
casts.arg("type").asIntegerVector().findFirst();
version(casts);
}
@Specialization
protected Object serialize(Object object, int conn, int type, RNull version, RNull refhook) {
return doSerializeToConnBase(this, object, conn, type, RRuntime.LOGICAL_NA, version, refhook);
protected Object serialize(Object object, int conn, int type, @SuppressWarnings("unused") Object version, @SuppressWarnings("unused") RNull refhook) {
return doSerializeToConnBase(this, object, conn, type);
}
@SuppressWarnings("unused")
@Specialization
protected Object serialize(Object object, RNull conn, int type, RNull version, RNull refhook) {
protected Object serialize(Object object, RNull conn, int type, @SuppressWarnings("unused") Object version, @SuppressWarnings("unused") RNull refhook) {
byte[] data = RSerialize.serialize(object, type, RSerialize.DEFAULT_VERSION, null);
return RDataFactory.createRawVector(data);
}
......@@ -192,14 +192,15 @@ public class SerializeFunctions {
Casts casts = new Casts(SerializeB.class);
connection(casts);
casts.arg("xdr").asLogicalVector().findFirst();
version(casts);
}
@Specialization
protected Object serializeB(Object object, int conn, byte xdrLogical, RNull version, RNull refhook) {
protected Object serializeB(Object object, int conn, byte xdrLogical, @SuppressWarnings("unused") Object version, @SuppressWarnings("unused") RNull refhook) {
if (!RRuntime.fromLogical(xdrLogical)) {
throw RError.nyi(this, "xdr==FALSE");
}
return doSerializeToConnBase(this, object, conn, RRuntime.LOGICAL_FALSE, xdrLogical, version, refhook);
return doSerializeToConnBase(this, object, conn, RRuntime.LOGICAL_FALSE);
}
}
}
......@@ -915,6 +915,7 @@ public final class RError extends RuntimeException implements TruffleException {
BAD_CONSTANT_COUNT("bad constant count"),
MUST_BE_MULTIPLE("argument '%s' must be a multiple of %d long"),
MUSTNOT_CONTAIN_NAS("argument '%s' must not contain NAs"),
VERSION_N_NOT_SUPPORTED("version %d not supported"),
ATOMIC_VECTOR_ARGUMENTS_ONLY("atomic vector arguments only");
public final String message;
......
......@@ -67052,6 +67052,15 @@ integer(0)
[26] fd 00 00 04 02 00 00 00 01 00 04 00 09 00 00 00 01 66 00 00 00 0a 00 00 00
[51] 01 80 00 00 00 00 00 00 fe 00 00 00 fe 00 00 00 fe
 
##com.oracle.truffle.r.test.builtins.TestBuiltin_serialize.testserialize#
#serialize('foo', NULL, version=2)
[1] 58 0a 00 00 00 02 00 03 04 00 00 02 03 00 00 00 00 10 00 00 00 01 00 04 00
[26] 09 00 00 00 03 66 6f 6f
##com.oracle.truffle.r.test.builtins.TestBuiltin_serialize.testserialize#
#serialize('foo', NULL, version=3)
Error in serialize("foo", NULL, version = 3) : version 3 not supported
##com.oracle.truffle.r.test.builtins.TestBuiltin_serialize.testserialize#
#setClass('foo', slots = c(x='numeric', y='numeric')); t1 <- new('foo', x=4, y=c(77,88)); options(keep.source=FALSE); serialize(t1, connection=NULL)
[1] 58 0a 00 00 00 02 00 03 04 00 00 02 03 00 00 01 03 19 00 00 04 02 00 00 00
......@@ -110,6 +110,8 @@ public class TestBuiltin_serialize extends TestBase {
assertEval(Ignored.ImplementationError, "options(keep.source=FALSE); fc <- setClass('FooSerial1', representation(a = 'call')); serialize(fc, connection=NULL)");
assertEval("{ options(keep.source=FALSE); f <- function() NULL; attributes(f) <- list(skeleton=quote(`<undef>`())); data <- serialize(f, conn=NULL); unserialize(conn=data) }");
assertEval("serialize('foo', NULL, version=2)");
assertEval("serialize('foo', NULL, version=3)");
}
@Test
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment