Skip to content
Snippets Groups Projects
Commit 275bd697 authored by Stepan Sindelar's avatar Stepan Sindelar
Browse files

[GR-2798] Fix unsupported specialization in serialize.

PullRequest: fastr/1294
parents 6eb0d27e 78f44dc2
No related branches found
No related tags found
No related merge requests found
......@@ -23,6 +23,7 @@
package com.oracle.truffle.r.nodes.builtin.base;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.asIntegerVector;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.eq;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.findFirst;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.integerValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue;
......@@ -32,6 +33,7 @@ import static com.oracle.truffle.r.runtime.builtins.RBehavior.IO;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
import java.io.IOException;
import java.util.function.Function;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Specialization;
......@@ -69,8 +71,7 @@ public class SerializeFunctions {
}
@TruffleBoundary
protected static Object doSerializeToConnBase(RBaseNode node, Object object, int connIndex, int type, @SuppressWarnings("unused") byte xdrLogical, @SuppressWarnings("unused") RNull version,
@SuppressWarnings("unused") RNull refhook) {
protected static Object doSerializeToConnBase(RBaseNode node, Object object, int connIndex, int type) {
// xdr is only relevant if ascii is false
try (RConnection openConn = RConnection.fromIndex(connIndex).forceOpen(type != RSerialize.XDR ? "wt" : "wb")) {
if (!openConn.canWrite()) {
......@@ -90,6 +91,12 @@ public class SerializeFunctions {
casts.arg("con").mustBe(integerValue()).asIntegerVector().findFirst();
}
private static void version(Casts casts) {
// This just validates the value. It must be either default NULL or 2. Specializations
// should use 'Object'
casts.arg("version").allowNull().asIntegerVector().findFirst().mustBe(eq(2), Message.VERSION_N_NOT_SUPPORTED, (Function<Object, Object>) n -> n);
}
@RBuiltin(name = "unserializeFromConn", kind = INTERNAL, parameterNames = {"con", "refhook"}, behavior = IO)
public abstract static class UnserializeFromConn extends RBuiltinNode.Arg2 {
......@@ -118,12 +125,12 @@ public class SerializeFunctions {
casts.arg("object").mustNotBeMissing();
connection(casts);
casts.arg("ascii").mustBe(logicalValue(), RError.Message.ASCII_NOT_LOGICAL).asLogicalVector().findFirst();
casts.arg("version").allowNull().mustBe(integerValue()).asIntegerVector().findFirst();
version(casts);
casts.arg("refhook").mustNotBeMissing();
}
@Specialization
protected Object doSerializeToConn(Object object, int conn, byte asciiLogical, RNull version, RNull refhook) {
protected Object doSerializeToConn(Object object, int conn, byte asciiLogical, @SuppressWarnings("unused") Object version, @SuppressWarnings("unused") RNull refhook) {
int type;
if (asciiLogical == RRuntime.LOGICAL_NA) {
type = RSerialize.ASCII_HEX;
......@@ -132,14 +139,7 @@ public class SerializeFunctions {
} else {
type = RSerialize.XDR;
}
return doSerializeToConnBase(this, object, conn, type, RRuntime.LOGICAL_NA, version, refhook);
}
@SuppressWarnings("unused")
@Specialization
protected Object doSerializeToConn(Object object, int conn, byte asciiLogical, int version, Object refhook) {
// TODO: implement "version" support
throw RError.error(this, RError.Message.UNIMPLEMENTED_ARG_TYPE, 4);
return doSerializeToConnBase(this, object, conn, type);
}
}
......@@ -170,16 +170,16 @@ public class SerializeFunctions {
Casts casts = new Casts(Serialize.class);
casts.arg("con").allowNull().mustBe(integerValue()).asIntegerVector().findFirst();
casts.arg("type").asIntegerVector().findFirst();
version(casts);
}
@Specialization
protected Object serialize(Object object, int conn, int type, RNull version, RNull refhook) {
return doSerializeToConnBase(this, object, conn, type, RRuntime.LOGICAL_NA, version, refhook);
protected Object serialize(Object object, int conn, int type, @SuppressWarnings("unused") Object version, @SuppressWarnings("unused") RNull refhook) {
return doSerializeToConnBase(this, object, conn, type);
}
@SuppressWarnings("unused")
@Specialization
protected Object serialize(Object object, RNull conn, int type, RNull version, RNull refhook) {
protected Object serialize(Object object, RNull conn, int type, @SuppressWarnings("unused") Object version, @SuppressWarnings("unused") RNull refhook) {
byte[] data = RSerialize.serialize(object, type, RSerialize.DEFAULT_VERSION, null);
return RDataFactory.createRawVector(data);
}
......@@ -192,14 +192,15 @@ public class SerializeFunctions {
Casts casts = new Casts(SerializeB.class);
connection(casts);
casts.arg("xdr").asLogicalVector().findFirst();
version(casts);
}
@Specialization
protected Object serializeB(Object object, int conn, byte xdrLogical, RNull version, RNull refhook) {
protected Object serializeB(Object object, int conn, byte xdrLogical, @SuppressWarnings("unused") Object version, @SuppressWarnings("unused") RNull refhook) {
if (!RRuntime.fromLogical(xdrLogical)) {
throw RError.nyi(this, "xdr==FALSE");
}
return doSerializeToConnBase(this, object, conn, RRuntime.LOGICAL_FALSE, xdrLogical, version, refhook);
return doSerializeToConnBase(this, object, conn, RRuntime.LOGICAL_FALSE);
}
}
}
......@@ -915,6 +915,7 @@ public final class RError extends RuntimeException implements TruffleException {
BAD_CONSTANT_COUNT("bad constant count"),
MUST_BE_MULTIPLE("argument '%s' must be a multiple of %d long"),
MUSTNOT_CONTAIN_NAS("argument '%s' must not contain NAs"),
VERSION_N_NOT_SUPPORTED("version %d not supported"),
ATOMIC_VECTOR_ARGUMENTS_ONLY("atomic vector arguments only");
public final String message;
......
......@@ -67100,6 +67100,15 @@ integer(0)
[26] fd 00 00 04 02 00 00 00 01 00 04 00 09 00 00 00 01 66 00 00 00 0a 00 00 00
[51] 01 80 00 00 00 00 00 00 fe 00 00 00 fe 00 00 00 fe
 
##com.oracle.truffle.r.test.builtins.TestBuiltin_serialize.testserialize#
#serialize('foo', NULL, version=2)
[1] 58 0a 00 00 00 02 00 03 04 00 00 02 03 00 00 00 00 10 00 00 00 01 00 04 00
[26] 09 00 00 00 03 66 6f 6f
##com.oracle.truffle.r.test.builtins.TestBuiltin_serialize.testserialize#
#serialize('foo', NULL, version=3)
Error in serialize("foo", NULL, version = 3) : version 3 not supported
##com.oracle.truffle.r.test.builtins.TestBuiltin_serialize.testserialize#
#setClass('foo', slots = c(x='numeric', y='numeric')); t1 <- new('foo', x=4, y=c(77,88)); options(keep.source=FALSE); serialize(t1, connection=NULL)
[1] 58 0a 00 00 00 02 00 03 04 00 00 02 03 00 00 01 03 19 00 00 04 02 00 00 00
......@@ -110,6 +110,8 @@ public class TestBuiltin_serialize extends TestBase {
assertEval(Ignored.ImplementationError, "options(keep.source=FALSE); fc <- setClass('FooSerial1', representation(a = 'call')); serialize(fc, connection=NULL)");
assertEval("{ options(keep.source=FALSE); f <- function() NULL; attributes(f) <- list(skeleton=quote(`<undef>`())); data <- serialize(f, conn=NULL); unserialize(conn=data) }");
assertEval("serialize('foo', NULL, version=2)");
assertEval("serialize('foo', NULL, version=3)");
}
@Test
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment