From 482fa57046da59b06dc4783fd0095dfa8f973d2c Mon Sep 17 00:00:00 2001 From: Florian Angerer <florian.angerer@oracle.com> Date: Fri, 24 Mar 2017 13:24:28 +0100 Subject: [PATCH] Implemented native function "R_new_custom_connection". Added class NativeConnections represented such connections and performing native calls for reading and writing. This changes make package "curl" working. --- .../fficall/src/jni/Connections.c | 176 +++++++++++- com.oracle.truffle.r.native/version.source | 2 +- .../builtin/base/ConnectionFunctions.java | 3 +- .../r/nodes/ffi/JavaUpCallsRFFIImpl.java | 26 +- .../truffle/r/nodes/ffi/RFFIUpCallMethod.java | 1 + .../r/nodes/ffi/TracingUpCallsRFFIImpl.java | 6 + .../r/runtime/conn/ConnectionSupport.java | 38 ++- .../r/runtime/conn/DelegateRConnection.java | 3 +- .../r/runtime/conn/NativeConnections.java | 250 ++++++++++++++++++ .../truffle/r/runtime/ffi/StdUpCallsRFFI.java | 2 + .../truffle/r/test/ExpectedTestOutput.test | 6 + .../r/test/library/base/TestConnections.java | 5 + 12 files changed, 496 insertions(+), 22 deletions(-) create mode 100644 com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/conn/NativeConnections.java diff --git a/com.oracle.truffle.r.native/fficall/src/jni/Connections.c b/com.oracle.truffle.r.native/fficall/src/jni/Connections.c index 8b4fc50ddc..907e9a08e0 100644 --- a/com.oracle.truffle.r.native/fficall/src/jni/Connections.c +++ b/com.oracle.truffle.r.native/fficall/src/jni/Connections.c @@ -33,6 +33,7 @@ static jmethodID getConnClassMethodID; static jmethodID getSummaryDescMethodID; static jmethodID isSeekableMethodID; static jmethodID getOpenModeMethodID; +static jmethodID newCustomConnectionMethodID; static jbyteArray wrap(JNIEnv *thisenv, void* buf, size_t n) { jbyteArray barr = (*thisenv)->NewByteArray(thisenv, n); @@ -82,6 +83,9 @@ void init_connections(JNIEnv *env) { /* String getOpenModeString(BaseRConnection) */ getOpenModeMethodID = checkGetMethodID(env, UpCallsRFFIClass, "getOpenModeString", "(Ljava/lang/Object;)Ljava/lang/String;", 0); + + /* int R_new_custom_connection(String, String, String) */ + newCustomConnectionMethodID = checkGetMethodID(env, UpCallsRFFIClass, "R_new_custom_connection", "(Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;J)Ljava/lang/Object;", 0); } static char *connStringToChars(JNIEnv *env, jstring string) { @@ -179,7 +183,177 @@ static void init_con(Rconnection new, char *description, int enc, } SEXP R_new_custom_connection(const char *description, const char *mode, const char *class_name, Rconnection *ptr) { - return unimplemented("R_new_custom_connection"); + JNIEnv *thisenv = getEnv(); + Rconnection new; + SEXP ans, class; + + new = (Rconnection) malloc(sizeof(struct Rconn)); + if (!new) + error(_("allocation of %s connection failed"), class_name); + + jstring jsDescription = (*thisenv)->NewStringUTF(thisenv, description); + jstring jsMode = (*thisenv)->NewStringUTF(thisenv, mode); + jstring jsClassName = (*thisenv)->NewStringUTF(thisenv, class_name); + printf("New conn addr = %llx\n", (jlong)new); + ans = (*thisenv)->CallObjectMethod(thisenv, UpCallsRFFIObject, newCustomConnectionMethodID, jsDescription, jsMode, jsClassName, (jlong)new); + printf("native fd = %d\n", asInteger(ans)); + if (ans) { + + new->class = (char *) malloc(strlen(class_name) + 1); + if (!new->class) { + free(new); + error(_("allocation of %s connection failed"), class_name); + } + strcpy(new->class, class_name); + new->description = (char *) malloc(strlen(description) + 1); + if (!new->description) { + free(new->class); + free(new); + error(_("allocation of %s connection failed"), class_name); + } + init_con(new, description, CE_NATIVE, mode); + /* all ptrs are init'ed to null_* so no need to repeat that, + but the following two are useful tools which could not be accessed otherwise */ + // TODO dummy_vfprintf and dummy_fgetc not implemented yet +// new->vfprintf = &dummy_vfprintf; +// new->fgetc = &dummy_fgetc; + + /* new->blocking = block; */ + new->encname[0] = 0; /* "" (should have the same effect as "native.enc") */ + new->ex_ptr = R_MakeExternalPtr(new->id, install("connection"), R_NilValue); + + class = allocVector(STRSXP, 2); + SET_STRING_ELT(class, 0, mkChar(class_name)); + SET_STRING_ELT(class, 1, mkChar("connection")); + classgets(ans, class); +// setAttrib(ans, R_ConnIdSymbol, new->ex_ptr); + + if (ptr) { + ptr[0] = new; + } + } + + return ans; +} + +/* + * The address of the Rconnection struct is passed to Java. + * Since down calls can only have Object parameters, we put the address into an int vector. + * Position 0 is the lower part and position 1 is the higher part of the address. + * This currently assumes max. 64-bit addresses ! + */ +static Rconnection convertToAddress(SEXP intVec) { + if(!Rf_isVector(intVec)) { + error(_("invalid address object")); + } + + // convert integer vector to array + int *arr = INTEGER(intVec); + + // bit-fiddle address + jlong ptr = (jlong) arr[1]; + ptr = (ptr<<32) | ((jlong)arr[0] & 0xFFFFFFFFl); + + return (Rconnection) ptr; +} + +/* + * This function is used as Java down call function to query the value of a connection's flag. + * DO NOT CHANGE ITS SIGNATURE ! + * If changing the signature is unavoidable, adapt it in class 'NativeConnections'. + */ +SEXP __GetFlagNativeConnection(SEXP rConnAddrObj, jstring jname) { + JNIEnv *thisenv = getEnv(); + Rconnection con = convertToAddress(rConnAddrObj); + const char *name = connStringToChars(thisenv, jname); + Rboolean result = 0; + + if(strcmp(name, "text") == 0) { + result = con->text; + } else if(strcmp(name, "isopen") == 0) { + result = con->isopen; + }else if(strcmp(name, "incomplete") == 0) { + result = con->incomplete; + }else if(strcmp(name, "canread") == 0) { + result = con->canread; + }else if(strcmp(name, "canwrite") == 0) { + result = con->canwrite; + }else if(strcmp(name, "canseek") == 0) { + result = con->canseek; + }else if(strcmp(name, "blocking") == 0) { + result = con->blocking; + } + free(name); + + return Rf_ScalarLogical(result); +} + +/* + * This function is used as Java down call function to invoke the open function of a natively created connection. + * DO NOT CHANGE ITS SIGNATURE ! + * If changing the signature is unavoidable, adapt it in class 'NativeConnections'. + */ +SEXP __OpenNativeConnection(SEXP rConnAddrObj) { + Rconnection con = convertToAddress(rConnAddrObj); + Rboolean success = con->open(con); + return Rf_ScalarLogical(success); +} + +/* + * This function is used as Java down call function to invoke the open function of a natively created connection. + * DO NOT CHANGE ITS SIGNATURE ! + * If changing the signature is unavoidable, adapt it in class 'NativeConnections'. + */ +SEXP __CloseNativeConnection(SEXP rConnAddrObj) { + Rconnection con = convertToAddress(rConnAddrObj); + con->close(con); + return NULL; +} + +/* + * This function is used as Java down call function to invoke the read function of a natively created connection. + * DO NOT CHANGE ITS SIGNATURE ! + * If changing the signature is unavoidable, adapt it in class 'NativeConnections'. + */ +SEXP __ReadNativeConnection(SEXP rConnAddrObj, jbyteArray bufObj, SEXP nVec) { + JNIEnv *thisenv = getEnv(); + int n = asInteger(nVec); + Rconnection con = convertToAddress(rConnAddrObj); + void *tmp_buf = (*thisenv)->GetByteArrayElements(thisenv, bufObj, NULL); + size_t nread = con->read(tmp_buf, 1, n, con); + // copy back and release buffer + (*thisenv)->ReleaseByteArrayElements(thisenv, bufObj, tmp_buf, JNI_COMMIT); + return Rf_ScalarInteger(nread); +} + +/* + * This function is used as Java down call function to invoke the write function of a natively created connection. + * DO NOT CHANGE ITS SIGNATURE ! + * If changing the signature is unavoidable, adapt it in class 'NativeConnections'. + */ +SEXP __WriteNativeConnection(SEXP rConnAddrObj, jbyteArray bufObj, SEXP nVec) { + JNIEnv *thisenv = getEnv(); + int n = asInteger(nVec); + Rconnection con = convertToAddress(rConnAddrObj); + void *bytes = (*thisenv)->GetByteArrayElements(thisenv, bufObj, NULL); + size_t nwritten = con->write(bytes, 1, n, con); + // just release buffer + (*thisenv)->ReleaseByteArrayElements(thisenv, bufObj, bytes, JNI_ABORT); + return Rf_ScalarInteger(nwritten); +} + +/* + * This function is used as Java down call function to invoke the seek function of a natively created connection. + * DO NOT CHANGE ITS SIGNATURE ! + * If changing the signature is unavoidable, adapt it in class 'NativeConnections'. + */ +SEXP __SeekNativeConnection(SEXP rConnAddrObj, SEXP whereObj, SEXP originObj, SEXP rwObj) { + Rconnection con = convertToAddress(rConnAddrObj); + double where = adDouble(whereObj); + int origin = asInteger(originObj); + int rw = asInteger(rwObj); + double oldPos = con->seek(con, where, origin, rw); + return Rf_ScalarReal(oldPos); } size_t R_ReadConnection(Rconnection con, void *buf, size_t n) { diff --git a/com.oracle.truffle.r.native/version.source b/com.oracle.truffle.r.native/version.source index 60d3b2f4a4..b6a7d89c68 100644 --- a/com.oracle.truffle.r.native/version.source +++ b/com.oracle.truffle.r.native/version.source @@ -1 +1 @@ -15 +16 diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ConnectionFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ConnectionFunctions.java index e8b5fdc22f..24a803191f 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ConnectionFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ConnectionFunctions.java @@ -515,7 +515,7 @@ public abstract class ConnectionFunctions { BaseRConnection baseCon = RConnection.fromIndex(object); Object[] data = new Object[NAMES.getLength()]; data[0] = baseCon.getSummaryDescription(); - data[1] = baseCon.getConnectionClass().getPrintName(); + data[1] = baseCon.getConnectionClass(); data[2] = baseCon.getOpenMode().summaryString(); data[3] = baseCon.getSummaryText(); data[4] = baseCon.isOpen() ? "opened" : "closed"; @@ -545,7 +545,6 @@ public abstract class ConnectionFunctions { } if (baseConn.isOpen()) { warning(RError.Message.ALREADY_OPEN_CONNECTION); - return RNull.instance; } baseConn.open(open); } catch (IOException ex) { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/ffi/JavaUpCallsRFFIImpl.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/ffi/JavaUpCallsRFFIImpl.java index f16b3e1726..70bc2ec622 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/ffi/JavaUpCallsRFFIImpl.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/ffi/JavaUpCallsRFFIImpl.java @@ -55,6 +55,8 @@ import com.oracle.truffle.r.runtime.RStartParams.SA_TYPE; import com.oracle.truffle.r.runtime.RType; import com.oracle.truffle.r.runtime.Utils; import com.oracle.truffle.r.runtime.conn.ConnectionSupport.BaseRConnection; +import com.oracle.truffle.r.runtime.conn.ConnectionSupport.InvalidConnection; +import com.oracle.truffle.r.runtime.conn.NativeConnections.NativeRConnection; import com.oracle.truffle.r.runtime.conn.RConnection; import com.oracle.truffle.r.runtime.context.Engine.ParseException; import com.oracle.truffle.r.runtime.context.RContext; @@ -381,7 +383,10 @@ public abstract class JavaUpCallsRFFIImpl implements UpCallsRFFI { case LGLSXP: return RDataFactory.createLogicalVector(new byte[n], RDataFactory.COMPLETE_VECTOR); case STRSXP: - return RDataFactory.createStringVector(new String[n], RDataFactory.COMPLETE_VECTOR); + // fill list with empty strings + String[] data = new String[n]; + Arrays.fill(data, ""); + return RDataFactory.createStringVector(data, RDataFactory.COMPLETE_VECTOR); case CPLXSXP: return RDataFactory.createComplexVector(new double[2 * n], RDataFactory.COMPLETE_VECTOR); case RAWSXP: @@ -437,7 +442,9 @@ public abstract class JavaUpCallsRFFIImpl implements UpCallsRFFI { case LGLSXP: return RDataFactory.createLogicalVector(new byte[nrow * ncol], RDataFactory.COMPLETE_VECTOR, dims); case STRSXP: - return RDataFactory.createStringVector(new String[nrow * ncol], RDataFactory.COMPLETE_VECTOR, dims); + String[] data = new String[nrow * ncol]; + Arrays.fill(data, ""); + return RDataFactory.createStringVector(data, RDataFactory.COMPLETE_VECTOR, dims); case CPLXSXP: return RDataFactory.createComplexVector(new double[2 * (nrow * ncol)], RDataFactory.COMPLETE_VECTOR, dims); default: @@ -1196,6 +1203,19 @@ public abstract class JavaUpCallsRFFIImpl implements UpCallsRFFI { return RASTUtils.createLanguageElement(expr); } + @Override + public Object R_new_custom_connection(Object description, Object mode, Object className, long readAddr) { + // TODO handle encoding properly ! + String strDescription = (String) description; + String strMode = (String) mode; + String strClassName = (String) className; + try { + return new NativeRConnection(strDescription, strMode, strClassName, readAddr).asVector(); + } catch (IOException e) { + return InvalidConnection.instance.asVector(); + } + } + @Override public int R_ReadConnection(int fd, byte[] buf) { @@ -1233,7 +1253,7 @@ public abstract class JavaUpCallsRFFIImpl implements UpCallsRFFI { @Override public String getConnectionClassString(Object x) { BaseRConnection conn = guaranteeInstanceOf(x, BaseRConnection.class); - return conn.getConnectionClass().getPrintName(); + return conn.getConnectionClass(); } @Override diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/ffi/RFFIUpCallMethod.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/ffi/RFFIUpCallMethod.java index 3cb37d4a25..25d2f9f385 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/ffi/RFFIUpCallMethod.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/ffi/RFFIUpCallMethod.java @@ -91,6 +91,7 @@ public enum RFFIUpCallMethod { R_isEqual("(object, object) : sint32"), R_isGlobal("(object) : sint32"), R_lsInternal3("(object, sint32, sint32) : object"), + R_new_custom_connection("(string, string, string, object) : object"), R_tryEval("(object, object, object) : object"), Rf_GetOption1("(object) : object"), Rf_PairToVectorList("(object) : object"), diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/ffi/TracingUpCallsRFFIImpl.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/ffi/TracingUpCallsRFFIImpl.java index c5ec109c8c..7cdfff95c3 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/ffi/TracingUpCallsRFFIImpl.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/ffi/TracingUpCallsRFFIImpl.java @@ -741,6 +741,12 @@ final class TracingUpCallsRFFIImpl implements UpCallsRFFI { return delegate.R_CHAR(x); } + @Override + public Object R_new_custom_connection(Object description, Object mode, Object className, long readAddr) { + RFFIUtils.traceUpCall("R_new_custom_connection", description, mode, className, readAddr); + return delegate.R_new_custom_connection(description, mode, className, readAddr); + } + @Override public int R_ReadConnection(int fd, byte[] buf) { RFFIUtils.traceUpCall("R_ReadConnection", fd, buf); diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/conn/ConnectionSupport.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/conn/ConnectionSupport.java index 46a93b40dc..825aa4cd44 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/conn/ConnectionSupport.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/conn/ConnectionSupport.java @@ -335,7 +335,9 @@ public class ConnectionSupport { Internal("internal"), PIPE("pipe"), FIFO("fifo"), - CHANNEL("Java Channel"); + CHANNEL("Java Channel"), + NATIVE("custom"), + INVALID("invalid"); private final String printName; @@ -358,7 +360,11 @@ public class ConnectionSupport { } } - public static final class InvalidConnection implements RConnection { + public static final class InvalidConnection extends BaseRConnection { + + protected InvalidConnection() { + super(ConnectionClass.INVALID, null); + } public static final InvalidConnection instance = new InvalidConnection(); @@ -395,7 +401,7 @@ public class ConnectionSupport { } @Override - public RConnection forceOpen(String modeString) throws IOException { + public BaseRConnection forceOpen(String modeString) throws IOException { throw RInternalError.shouldNotReachHere("INVALID CONNECTION"); } @@ -470,22 +476,27 @@ public class ConnectionSupport { } @Override - public void pushBack(RAbstractStringVector lines, boolean addNewLine) { + public long seek(long offset, SeekMode seekMode, SeekRWMode seekRWMode) throws IOException { throw RInternalError.shouldNotReachHere("INVALID CONNECTION"); } @Override - public long seek(long offset, SeekMode seekMode, SeekRWMode seekRWMode) throws IOException { + public ByteChannel getChannel() throws IOException { throw RInternalError.shouldNotReachHere("INVALID CONNECTION"); } @Override - public ByteChannel getChannel() throws IOException { + public void truncate() throws IOException { throw RInternalError.shouldNotReachHere("INVALID CONNECTION"); } @Override - public void truncate() throws IOException { + protected void createDelegateConnection() throws IOException { + throw RInternalError.shouldNotReachHere("INVALID CONNECTION"); + } + + @Override + public String getSummaryDescription() { throw RInternalError.shouldNotReachHere("INVALID CONNECTION"); } } @@ -639,8 +650,8 @@ public class ConnectionSupport { this.openMode = mode; } - public final ConnectionClass getConnectionClass() { - return conClass; + public String getConnectionClass() { + return conClass.getPrintName(); } /** @@ -730,11 +741,10 @@ public class ConnectionSupport { * This is used exclusively by the {@code open} builtin. */ public void open(String modeString) throws IOException { - if (openMode.abstractOpenMode == AbstractOpenMode.Lazy) { - // modeString may override the default - openMode = new OpenMode(modeString == null ? openMode.modeString : modeString); + openMode = new OpenMode(modeString == null ? openMode.modeString : modeString); + if (!isOpen()) { + createDelegateConnection(); } - createDelegateConnection(); } protected void checkOpen() { @@ -1067,7 +1077,7 @@ public class ConnectionSupport { } public final RAbstractIntVector asVector() { - String[] classes = new String[]{ConnectionSupport.getBaseConnection(this).getConnectionClass().getPrintName(), "connection"}; + String[] classes = new String[]{ConnectionSupport.getBaseConnection(this).getConnectionClass(), "connection"}; RAbstractIntVector result = RDataFactory.createIntVector(new int[]{getDescriptor()}, true); diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/conn/DelegateRConnection.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/conn/DelegateRConnection.java index e31aaa0bd5..74bc8391a0 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/conn/DelegateRConnection.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/conn/DelegateRConnection.java @@ -549,7 +549,8 @@ abstract class DelegateRConnection implements RConnection, ByteChannel { @Override public int readBin(ByteBuffer buffer) throws IOException { - return read(buffer); + int read = read(buffer); + return read < 0 ? 0 : read; } /** diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/conn/NativeConnections.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/conn/NativeConnections.java new file mode 100644 index 0000000000..531455c3a2 --- /dev/null +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/conn/NativeConnections.java @@ -0,0 +1,250 @@ +package com.oracle.truffle.r.runtime.conn; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ByteChannel; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import com.oracle.truffle.api.RootCallTarget; +import com.oracle.truffle.r.runtime.RInternalError; +import com.oracle.truffle.r.runtime.RRuntime; +import com.oracle.truffle.r.runtime.conn.ConnectionSupport.AbstractOpenMode; +import com.oracle.truffle.r.runtime.conn.ConnectionSupport.BaseRConnection; +import com.oracle.truffle.r.runtime.conn.ConnectionSupport.ConnectionClass; +import com.oracle.truffle.r.runtime.data.RDataFactory; +import com.oracle.truffle.r.runtime.data.RDoubleVector; +import com.oracle.truffle.r.runtime.data.RIntVector; +import com.oracle.truffle.r.runtime.data.RLogicalVector; +import com.oracle.truffle.r.runtime.ffi.CallRFFI; +import com.oracle.truffle.r.runtime.ffi.DLL; +import com.oracle.truffle.r.runtime.ffi.DLL.DLLInfo; +import com.oracle.truffle.r.runtime.ffi.DLL.SymbolHandle; +import com.oracle.truffle.r.runtime.ffi.NativeCallInfo; + +/** + * Represents a custom connection created in native code and having its own native read and write + * functions. + */ +public class NativeConnections { + + private static final String OPEN_NATIVE_CONNECTION = "__OpenNativeConnection"; + private static final String CLOSE_NATIVE_CONNECTION = "__CloseNativeConnection"; + private static final String READ_NATIVE_CONNECTION = "__ReadNativeConnection"; + private static final String WRITE_NATIVE_CONNECTION = "__WriteNativeConnection"; + private static final String GET_FLAG_NATIVE_CONNECTION = "__GetFlagNativeConnection"; + private static final String SEEK_NATIVE_CONNECTION = "__SeekNativeConnection"; + + private static final Map<String, NativeCallInfo> callInfoTable = new HashMap<>(4); + + private static NativeCallInfo getNativeFunctionInfo(String name) { + NativeCallInfo nativeCallInfo = callInfoTable.get(name); + if (nativeCallInfo == null) { + DLLInfo findLibraryContainingSymbol = DLL.findLibraryContainingSymbol(name); + SymbolHandle findSymbol = DLL.findSymbol(name, findLibraryContainingSymbol); + nativeCallInfo = new NativeCallInfo(name, findSymbol, findLibraryContainingSymbol); + callInfoTable.put(name, nativeCallInfo); + } + return nativeCallInfo; + } + + public static class NativeRConnection extends BaseRConnection { + private final String customConClass; + private final String description; + private final long addr; + + public NativeRConnection(String description, String modeString, String customConClass, long addr) throws IOException { + super(ConnectionClass.NATIVE, modeString, AbstractOpenMode.Read); + this.customConClass = Objects.requireNonNull(customConClass); + this.description = Objects.requireNonNull(description); + this.addr = addr; + } + + @Override + protected void createDelegateConnection() throws IOException { + DelegateRConnection delegate = null; + switch (getOpenMode().abstractOpenMode) { + case Read: + case ReadBinary: + delegate = new ReadNativeConnection(this); + break; + case Write: + case WriteBinary: + delegate = new WriteNativeConnection(this); + break; + } + setDelegate(delegate); + } + + @Override + public String getSummaryDescription() { + return description; + } + + @Override + public String getConnectionClass() { + return customConClass; + } + + public long getNativeAddress() { + return addr; + } + + public boolean getFlag(String name) { + NativeCallInfo ni = NativeConnections.getNativeFunctionInfo(GET_FLAG_NATIVE_CONNECTION); + RootCallTarget nativeCallTarget = CallRFFI.InvokeCallRootNode.create().getCallTarget(); + + RIntVector addrVec = convertAddrToIntVec(addr); + Object result = nativeCallTarget.call(ni, new Object[]{addrVec, name}); + if (result instanceof RLogicalVector) { + return ((RLogicalVector) result).getDataAt(0) == RRuntime.LOGICAL_TRUE; + } + throw new RInternalError("could not get value of flag " + name); + } + } + + static class ReadNativeConnection extends DelegateReadRConnection { + + private final ByteChannel ch; + + protected ReadNativeConnection(NativeRConnection base) throws IOException { + super(base, 4096); + ch = new NativeChannel(base); + NativeConnections.openNative(base.addr); + } + + @Override + public ByteChannel getChannel() throws IOException { + return ch; + } + + @Override + protected long seekInternal(long offset, SeekMode seekMode, SeekRWMode seekRWMode) throws IOException { + RDoubleVector where = RDataFactory.createDoubleVectorFromScalar(offset); + RIntVector seekCode; + switch (seekMode) { + case CURRENT: + seekCode = RDataFactory.createIntVectorFromScalar(1); + break; + case END: + seekCode = RDataFactory.createIntVectorFromScalar(2); + break; + case START: + seekCode = RDataFactory.createIntVectorFromScalar(0); + break; + default: + seekCode = RDataFactory.createIntVectorFromScalar(-1); + break; + } + RIntVector rwCode = RDataFactory.createIntVectorFromScalar(1); + + NativeCallInfo ni = NativeConnections.getNativeFunctionInfo(SEEK_NATIVE_CONNECTION); + RootCallTarget nativeCallTarget = CallRFFI.InvokeCallRootNode.create().getCallTarget(); + RIntVector addrVec = convertAddrToIntVec(((NativeRConnection) base).addr); + Object result = nativeCallTarget.call(ni, new Object[]{addrVec, where, seekCode, rwCode}); + if (result instanceof RDoubleVector) { + return (long) ((RDoubleVector) result).getDataAt(0); + } + throw RInternalError.shouldNotReachHere("unexpected result type"); + } + + @Override + public boolean isSeekable() { + return ((NativeRConnection) base).getFlag("canseek"); + } + } + + static class WriteNativeConnection extends DelegateWriteRConnection { + + private final ByteChannel ch; + + protected WriteNativeConnection(NativeRConnection base) throws IOException { + super(base); + ch = new NativeChannel(base); + NativeConnections.openNative(base.addr); + } + + @Override + public ByteChannel getChannel() throws IOException { + return ch; + } + + @Override + public boolean isSeekable() { + return ((NativeRConnection) base).getFlag("canseek"); + } + } + + private static void openNative(long addr) throws IOException { + NativeCallInfo ni = NativeConnections.getNativeFunctionInfo(OPEN_NATIVE_CONNECTION); + RootCallTarget nativeCallTarget = CallRFFI.InvokeCallRootNode.create().getCallTarget(); + + RIntVector addrVec = convertAddrToIntVec(addr); + Object result = nativeCallTarget.call(ni, new Object[]{addrVec}); + if (!(result instanceof RLogicalVector && ((RLogicalVector) result).getDataAt(0) == RRuntime.LOGICAL_TRUE)) { + throw new IOException("could not open connection"); + } + } + + private static class NativeChannel implements ByteChannel { + + private final NativeRConnection base; + + public NativeChannel(NativeRConnection base) { + this.base = base; + } + + @Override + public int read(ByteBuffer dst) throws IOException { + NativeCallInfo ni = NativeConnections.getNativeFunctionInfo(READ_NATIVE_CONNECTION); + RootCallTarget nativeCallTarget = CallRFFI.InvokeCallRootNode.create().getCallTarget(); + RIntVector vec = NativeConnections.convertAddrToIntVec(base.addr); + Object call = nativeCallTarget.call(ni, new Object[]{vec, dst.array(), dst.remaining()}); + + if (call instanceof RIntVector) { + int nread = ((RIntVector) call).getDataAt(0); + // update buffer's position ! + dst.position(nread); + return nread; + } + + throw RInternalError.shouldNotReachHere("unexpected result type"); + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public void close() throws IOException { + NativeCallInfo ni = NativeConnections.getNativeFunctionInfo(CLOSE_NATIVE_CONNECTION); + RootCallTarget nativeCallTarget = CallRFFI.InvokeCallRootNode.create().getCallTarget(); + RIntVector vec = NativeConnections.convertAddrToIntVec(base.addr); + nativeCallTarget.call(ni, new Object[]{vec}); + } + + @Override + public int write(ByteBuffer src) throws IOException { + NativeCallInfo ni = NativeConnections.getNativeFunctionInfo(WRITE_NATIVE_CONNECTION); + RootCallTarget nativeCallTarget = CallRFFI.InvokeCallRootNode.create().getCallTarget(); + + RIntVector vec = NativeConnections.convertAddrToIntVec(base.addr); + Object result = nativeCallTarget.call(ni, new Object[]{vec, src.array(), src.remaining()}); + + if (result instanceof RIntVector) { + return ((RIntVector) result).getDataAt(0); + } + + throw RInternalError.shouldNotReachHere("unexpected result type"); + } + } + + static RIntVector convertAddrToIntVec(long addr) { + int high = (int) (addr >> 32); + int low = (int) addr; + RIntVector vec = RDataFactory.createIntVector(new int[]{low, high}, true); + return vec; + } +} diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/StdUpCallsRFFI.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/StdUpCallsRFFI.java index 86f784db14..82a1e81ff4 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/StdUpCallsRFFI.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/StdUpCallsRFFI.java @@ -254,6 +254,8 @@ public interface StdUpCallsRFFI { Object R_CHAR(Object x); + Object R_new_custom_connection(@RFFICstring Object description, @RFFICstring Object mode, @RFFICstring Object className, long readAddr); + int R_ReadConnection(int fd, byte[] buf); int R_WriteConnection(int fd, byte[] buf); diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index ec06e4b77c..eeaec0855d 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -77831,6 +77831,12 @@ In readLines(zz, 2, warn = T, skipNul = F) : #{ zz <- file('',"w+b", blocking=T); writeBin(as.raw(c(97,98,99,100,0,101,10,65,66,67,10)), zz, useBytes=T); seek(zz, 0); res <- readLines(zz, 2, warn=T, skipNul=T); close(zz); res } [1] "abcde" "ABC" +##com.oracle.truffle.r.test.library.base.TestConnections.testReopen# +#{ con <- rawConnection(charToRaw('hello\nworld\n')); readLines(con, 1); open(con, 'rb'); bin <- readBin(con, raw(), 999); close(con); rawToChar(bin) } +[1] "world\n" +Warning message: +In open.connection(con, "rb") : connection is already open + ##com.oracle.truffle.r.test.library.base.TestConnections.testSeekTextConnection# #{ zz <- textConnection("Hello, World!"); res <- isSeekable(zz); close(zz); res } [1] FALSE diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/base/TestConnections.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/base/TestConnections.java index d96b63d0c3..7b8806d22c 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/base/TestConnections.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/base/TestConnections.java @@ -176,6 +176,11 @@ public class TestConnections extends TestRBase { assertEval("zz <- file('', 'w+'); summary(zz); close(zz)"); } + @Test + public void testReopen() { + assertEval("{ con <- rawConnection(charToRaw('hello\\nworld\\n')); readLines(con, 1); open(con, 'rb'); bin <- readBin(con, raw(), 999); close(con); rawToChar(bin) }"); + } + @Test public void testFileOpenRaw() { Assert.assertTrue("Could not create required temp file for test.", Files.exists(tempFileGzip)); -- GitLab