Skip to content
Snippets Groups Projects
Commit d8fd6cc8 authored by Mick Jordan's avatar Mick Jordan
Browse files

FFI: Refactor native array handling to be correct wrt nested calls

parent 9f5247a7
Branches
No related tags found
No related merge requests found
...@@ -75,10 +75,10 @@ static jmethodID SYMVALUE_MethodID; ...@@ -75,10 +75,10 @@ static jmethodID SYMVALUE_MethodID;
static jmethodID SET_SYMVALUE_MethodID; static jmethodID SET_SYMVALUE_MethodID;
static jmethodID SET_STRING_ELT_MethodID; static jmethodID SET_STRING_ELT_MethodID;
static jmethodID SET_VECTOR_ELT_MethodID; static jmethodID SET_VECTOR_ELT_MethodID;
static jmethodID RAW_MethodID; jmethodID RAW_MethodID;
static jmethodID INTEGER_MethodID; jmethodID INTEGER_MethodID;
static jmethodID REAL_MethodID; jmethodID REAL_MethodID;
static jmethodID LOGICAL_MethodID; jmethodID LOGICAL_MethodID;
static jmethodID STRING_ELT_MethodID; static jmethodID STRING_ELT_MethodID;
static jmethodID VECTOR_ELT_MethodID; static jmethodID VECTOR_ELT_MethodID;
static jmethodID LENGTH_MethodID; static jmethodID LENGTH_MethodID;
...@@ -526,24 +526,25 @@ int Rf_nrows(SEXP x) { ...@@ -526,24 +526,25 @@ int Rf_nrows(SEXP x) {
SEXP Rf_protect(SEXP x) { SEXP Rf_protect(SEXP x) {
TRACE(TARGp, x);
return x; return x;
} }
void Rf_unprotect(int x) { void Rf_unprotect(int x) {
// TODO perhaps we can use this TRACE(TARGp, x);
} }
void R_ProtectWithIndex(SEXP x, PROTECT_INDEX *y) { void R_ProtectWithIndex(SEXP x, PROTECT_INDEX *y) {
TRACE(TARGpd, x,y);
} }
void R_Reprotect(SEXP x, PROTECT_INDEX y) { void R_Reprotect(SEXP x, PROTECT_INDEX y) {
TRACE(TARGpd, x,y);
} }
void Rf_unprotect_ptr(SEXP x) { void Rf_unprotect_ptr(SEXP x) {
// TODO perhaps we can use this TRACE(TARGp, x);
} }
#define BUFSIZE 8192 #define BUFSIZE 8192
...@@ -1101,58 +1102,28 @@ int SETLEVELS(SEXP x, int v){ ...@@ -1101,58 +1102,28 @@ int SETLEVELS(SEXP x, int v){
int *LOGICAL(SEXP x){ int *LOGICAL(SEXP x){
TRACE(TARGp, x); TRACE(TARGp, x);
JNIEnv *thisenv = getEnv(); JNIEnv *thisenv = getEnv();
jint *data = (jint *) findCopiedObject(thisenv, x); jint *data = (jint *) getNativeArray(thisenv, x, LGLSXP);
if (data == NULL) {
jbyteArray byteArray = (*thisenv)->CallStaticObjectMethod(thisenv, CallRFFIHelperClass, LOGICAL_MethodID, x);
int len = (*thisenv)->GetArrayLength(thisenv, byteArray);
jbyte* internalData = (*thisenv)->GetByteArrayElements(thisenv, byteArray, NULL);
data = malloc(len * sizeof(int));
for (int i = 0; i < len; i++) {
char value = internalData[i];
data[i] = value == 0 ? FALSE : value == 1 ? TRUE : NA_INTEGER;
}
(*thisenv)->ReleaseByteArrayElements(thisenv, byteArray, internalData, JNI_ABORT);
addCopiedObject(thisenv, x, LGLSXP, byteArray, data);
}
return data; return data;
} }
int *INTEGER(SEXP x){ int *INTEGER(SEXP x){
TRACE(TARGp, x); TRACE(TARGp, x);
JNIEnv *thisenv = getEnv(); JNIEnv *thisenv = getEnv();
jint *data = (jint *) findCopiedObject(thisenv, x); jint *data = (jint *) getNativeArray(thisenv, x, INTSXP);
if (data == NULL) {
jintArray intArray = (*thisenv)->CallStaticObjectMethod(thisenv, CallRFFIHelperClass, INTEGER_MethodID, x);
int len = (*thisenv)->GetArrayLength(thisenv, intArray);
data = (*thisenv)->GetIntArrayElements(thisenv, intArray, NULL);
addCopiedObject(thisenv, x, INTSXP, intArray, data);
}
return data; return data;
} }
Rbyte *RAW(SEXP x){ Rbyte *RAW(SEXP x){
JNIEnv *thisenv = getEnv(); JNIEnv *thisenv = getEnv();
jbyte *data = (jbyte *) findCopiedObject(thisenv, x); Rbyte *data = (Rbyte*) getNativeArray(thisenv, x, RAWSXP);
if (data == NULL) { return data;
jbyteArray byteArray = (*thisenv)->CallStaticObjectMethod(thisenv, CallRFFIHelperClass, RAW_MethodID, x);
int len = (*thisenv)->GetArrayLength(thisenv, byteArray);
data = (*thisenv)->GetByteArrayElements(thisenv, byteArray, NULL);
addCopiedObject(thisenv, x, RAWSXP, byteArray, data);
}
return (Rbyte*) data;
} }
double *REAL(SEXP x){ double *REAL(SEXP x){
JNIEnv *thisenv = getEnv(); JNIEnv *thisenv = getEnv();
jdouble *data = (jdouble *) findCopiedObject(thisenv, x); jdouble *data = (jdouble *) getNativeArray(thisenv, x, REALSXP);
if (data == NULL) {
jdoubleArray doubleArray = (*thisenv)->CallStaticObjectMethod(thisenv, CallRFFIHelperClass, REAL_MethodID, x);
int len = (*thisenv)->GetArrayLength(thisenv, doubleArray);
data = (*thisenv)->GetDoubleArrayElements(thisenv, doubleArray, NULL);
addCopiedObject(thisenv, x, REALSXP, doubleArray, data);
}
return data; return data;
} }
......
...@@ -30,8 +30,7 @@ ...@@ -30,8 +30,7 @@
* that needs to be saved for reuse in the many R functions such as Rf_allocVector. * that needs to be saved for reuse in the many R functions such as Rf_allocVector.
* Currently only single threaded access is permitted (via a semaphore in CallRFFIWithJNI) * Currently only single threaded access is permitted (via a semaphore in CallRFFIWithJNI)
* so we are safe to use static variables. TODO Figure out where to store such state * so we are safe to use static variables. TODO Figure out where to store such state
* (portably) for MT use. JNI provides no help. N.B. The MT restriction also precludes * (portably) for MT use. JNI provides no help.
* recursive calls.
*/ */
jclass CallRFFIHelperClass; jclass CallRFFIHelperClass;
jclass RDataFactoryClass; jclass RDataFactoryClass;
...@@ -54,26 +53,37 @@ static int alwaysUseGlobal = 0; ...@@ -54,26 +53,37 @@ static int alwaysUseGlobal = 0;
static SEXP *cachedGlobalRefs; static SEXP *cachedGlobalRefs;
static int cachedGlobalRefsLength; static int cachedGlobalRefsLength;
typedef struct CopiedVectors_struct { // Data structure for managing the required copying of
// Java arrays to return C arrays, e.g, int*.
// N.B. There are actually two levels to this as FastR
// wraps, e.g., int[] in an RIntVector.
typedef struct nativeArrayTable_struct {
SEXPTYPE type; SEXPTYPE type;
SEXP obj; SEXP obj; // The jobject (SEXP) that data is derived from (e.g, RIntVector)
void *jArray; void *jArray; // the jarray corresponding to obj
void *data; void *data; // the (possibly) copied (or pinned) data from JNI GetXXXArrayElements
} CopiedVector; } NativeArrayElem;
#define COPIED_VECTORS_INITIAL_SIZE 64 #define NATIVE_ARRAY_TABLE_INITIAL_SIZE 64
// A table of vectors that have been accessed and whose contents, e.g. the actual data // A table of vectors that have been accessed and whose contents, e.g. the actual data
// as a primitive array have been copied and handed out to the native code. // as a primitive array have been copied and handed out to the native code.
static CopiedVector *copiedVectors; static NativeArrayElem *nativeArrayTable;
// hwm of copiedVectors // hwm of nativeArrayTable
static int copiedVectorsIndex; static int nativeArrayTableHwm;
static int copiedVectorsLength; static int nativeArrayTableLength;
static void releaseNativeArray(JNIEnv *env, int index);
static int isEmbedded = 0; static int isEmbedded = 0;
void setEmbedded() { void setEmbedded() {
isEmbedded = 1; isEmbedded = 1;
} }
// native down call depth, indexes nativeArrayTableHwmStack
int callDepth;
#define NATIVE_ARRAY_TABLE_HWM_STACK_SIZE 16
int nativeArrayTableHwmStack[NATIVE_ARRAY_TABLE_HWM_STACK_SIZE] ;
void init_utils(JNIEnv *env) { void init_utils(JNIEnv *env) {
curenv = env; curenv = env;
if (TRACE_ENABLED && traceFile == NULL) { if (TRACE_ENABLED && traceFile == NULL) {
...@@ -92,6 +102,8 @@ void init_utils(JNIEnv *env) { ...@@ -92,6 +102,8 @@ void init_utils(JNIEnv *env) {
fprintf(stderr, "%s, %d", "failed to fdopen trace file on JNI side\n", errno); fprintf(stderr, "%s, %d", "failed to fdopen trace file on JNI side\n", errno);
exit(1); exit(1);
} }
// no buffering
setvbuf(traceFile, (char*) NULL, _IONBF, 0);
} }
} }
RDataFactoryClass = checkFindClass(env, "com/oracle/truffle/r/runtime/data/RDataFactory"); RDataFactoryClass = checkFindClass(env, "com/oracle/truffle/r/runtime/data/RDataFactory");
...@@ -103,9 +115,9 @@ void init_utils(JNIEnv *env) { ...@@ -103,9 +115,9 @@ void init_utils(JNIEnv *env) {
validateMethodID = checkGetMethodID(env, CallRFFIHelperClass, "validate", "(Ljava/lang/Object;)Ljava/lang/Object;", 1); validateMethodID = checkGetMethodID(env, CallRFFIHelperClass, "validate", "(Ljava/lang/Object;)Ljava/lang/Object;", 1);
cachedGlobalRefs = calloc(CACHED_GLOBALREFS_INITIAL_SIZE, sizeof(SEXP)); cachedGlobalRefs = calloc(CACHED_GLOBALREFS_INITIAL_SIZE, sizeof(SEXP));
cachedGlobalRefsLength = CACHED_GLOBALREFS_INITIAL_SIZE; cachedGlobalRefsLength = CACHED_GLOBALREFS_INITIAL_SIZE;
copiedVectors = calloc(COPIED_VECTORS_INITIAL_SIZE, sizeof(CopiedVector)); nativeArrayTable = calloc(NATIVE_ARRAY_TABLE_INITIAL_SIZE, sizeof(NativeArrayElem));
copiedVectorsLength = COPIED_VECTORS_INITIAL_SIZE; nativeArrayTableLength = NATIVE_ARRAY_TABLE_INITIAL_SIZE;
copiedVectorsIndex = 0; nativeArrayTableHwm = 0;
} }
const char *stringToChars(JNIEnv *jniEnv, jstring string) { const char *stringToChars(JNIEnv *jniEnv, jstring string) {
...@@ -123,18 +135,147 @@ const char *stringToChars(JNIEnv *jniEnv, jstring string) { ...@@ -123,18 +135,147 @@ const char *stringToChars(JNIEnv *jniEnv, jstring string) {
void callEnter(JNIEnv *env, jmp_buf *jmpbuf) { void callEnter(JNIEnv *env, jmp_buf *jmpbuf) {
setEnv(env); setEnv(env);
callErrorJmpBuf = jmpbuf; callErrorJmpBuf = jmpbuf;
// printf("callEnter\n"); if (callDepth >= NATIVE_ARRAY_TABLE_HWM_STACK_SIZE) {
fatalError("call stack overflow\n");
}
nativeArrayTableHwmStack[callDepth] = nativeArrayTableHwm;
callDepth++;
} }
jmp_buf *getErrorJmpBuf() { jmp_buf *getErrorJmpBuf() {
return callErrorJmpBuf; return callErrorJmpBuf;
} }
void releaseCopiedVector(JNIEnv *env, CopiedVector cv) { void callExit(JNIEnv *env) {
if (cv.obj != NULL) { int oldHwm = nativeArrayTableHwmStack[callDepth - 1];
#if TRACE_COPIES for (int i = oldHwm; i < nativeArrayTableHwm; i++) {
fprintf(traceFile, "releaseCopiedVector(%p)\n", cv.obj); releaseNativeArray(env, i);
}
nativeArrayTableHwm = oldHwm;
callDepth--;
}
void invalidateNativeArray(JNIEnv *env, SEXP oldObj) {
int i;
for (i = 0; i < nativeArrayTableHwm; i++) {
NativeArrayElem cv = nativeArrayTable[i];
if ((*env)->IsSameObject(env, cv.obj, oldObj)) {
#if TRACE_NATIVE_ARRAYS
fprintf(traceFile, "invalidateNativeArray(%p): found\n", oldObj);
#endif
releaseNativeArray(env, &cv);
nativeArrayTable[i].obj = NULL;
}
}
#if TRACE_NATIVE_ARRAYS
fprintf(traceFile, "invalidateNativeArray(%p): not found\n", oldObj);
#endif
}
static void *findNativeArray(JNIEnv *env, SEXP x) {
int i;
for (i = 0; i < nativeArrayTableHwm; i++) {
NativeArrayElem cv = nativeArrayTable[i];
if (cv.obj != NULL) {
if ((*env)->IsSameObject(env, cv.obj, x)) {
void *data = cv.data;
#if TRACE_NATIVE_ARRAYS
fprintf(traceFile, "findNativeArray(%p): found %p\n", x, data);
#endif
return data;
}
}
}
#if TRACE_NATIVE_ARRAYS
fprintf(traceFile, "findNativeArray(%p): not found\n", x);
#endif
return NULL;
}
static void addNativeArray(JNIEnv *env, SEXP x, SEXPTYPE type, void *jArray, void *data) {
#if TRACE_NATIVE_ARRAYS
fprintf(traceFile, "addNativeArray(x=%p, t=%p, ix=%d)\n", x, data, nativeArrayTableHwm);
#endif
// check for overflow
if (nativeArrayTableHwm >= nativeArrayTableLength) {
int newLength = 2 * nativeArrayTableLength;
NativeArrayElem *newnativeArrayTable = calloc(newLength, sizeof(NativeArrayElem));
if (newnativeArrayTable == NULL) {
fatalError("FFI copied vectors table expansion failure");
}
memcpy(newnativeArrayTable, nativeArrayTable, nativeArrayTableLength * sizeof(NativeArrayElem));
free(nativeArrayTable);
nativeArrayTable = newnativeArrayTable;
nativeArrayTableLength = newLength;
}
nativeArrayTable[nativeArrayTableHwm].obj = x;
nativeArrayTable[nativeArrayTableHwm].data = data;
nativeArrayTable[nativeArrayTableHwm].type = type;
nativeArrayTable[nativeArrayTableHwm].jArray = jArray;
nativeArrayTableHwm++;
}
void *getNativeArray(JNIEnv *thisenv, SEXP x, SEXPTYPE type) {
void *data = findNativeArray(thisenv, x);
jboolean isCopy;
if (data == NULL) {
jarray jArray;
switch (type) {
case INTSXP: {
jintArray intArray = (*thisenv)->CallStaticObjectMethod(thisenv, CallRFFIHelperClass, INTEGER_MethodID, x);
int len = (*thisenv)->GetArrayLength(thisenv, intArray);
data = (*thisenv)->GetIntArrayElements(thisenv, intArray, &isCopy);
jArray = intArray;
break;
}
case REALSXP: {
jdoubleArray doubleArray = (*thisenv)->CallStaticObjectMethod(thisenv, CallRFFIHelperClass, REAL_MethodID, x);
int len = (*thisenv)->GetArrayLength(thisenv, doubleArray);
data = (*thisenv)->GetDoubleArrayElements(thisenv, doubleArray, &isCopy);
jArray = doubleArray;
break;
}
case RAWSXP: {
jbyteArray byteArray = (*thisenv)->CallStaticObjectMethod(thisenv, CallRFFIHelperClass, RAW_MethodID, x);
int len = (*thisenv)->GetArrayLength(thisenv, byteArray);
data = (*thisenv)->GetByteArrayElements(thisenv, byteArray, &isCopy);
jArray = byteArray;
break;
}
case LGLSXP: {
// Special treatment becuase R FFI wants int* and FastR represents using byte[]
jbyteArray byteArray = (*thisenv)->CallStaticObjectMethod(thisenv, CallRFFIHelperClass, LOGICAL_MethodID, x);
int len = (*thisenv)->GetArrayLength(thisenv, byteArray);
jbyte* internalData = (*thisenv)->GetByteArrayElements(thisenv, byteArray, &isCopy);
int* idata = malloc(len * sizeof(int));
for (int i = 0; i < len; i++) {
char value = internalData[i];
idata[i] = value == 0 ? FALSE : value == 1 ? TRUE : NA_INTEGER;
}
(*thisenv)->ReleaseByteArrayElements(thisenv, byteArray, internalData, JNI_ABORT);
jArray = byteArray;
data = idata;
break;
}
default:
fatalError("getNativeArray: unexpected type");
}
addNativeArray(thisenv, x, type, jArray, data);
}
return data;
}
static void releaseNativeArray(JNIEnv *env, int i) {
NativeArrayElem cv = nativeArrayTable[i];
#if TRACE_NATIVE_ARRAYS
fprintf(traceFile, "releaseNativeArray(x=%p, ix=%d)\n", cv.obj, i);
#endif #endif
if (cv.obj != NULL) {
switch (cv.type) { switch (cv.type) {
case INTSXP: { case INTSXP: {
jintArray intArray = (jintArray) cv.jArray; jintArray intArray = (jintArray) cv.jArray;
...@@ -152,6 +293,7 @@ void releaseCopiedVector(JNIEnv *env, CopiedVector cv) { ...@@ -152,6 +293,7 @@ void releaseCopiedVector(JNIEnv *env, CopiedVector cv) {
internalData[i] = data[i] == NA_INTEGER ? 255 : (jbyte) data[i]; internalData[i] = data[i] == NA_INTEGER ? 255 : (jbyte) data[i];
} }
(*env)->ReleaseByteArrayElements(env, byteArray, internalData, 0); (*env)->ReleaseByteArrayElements(env, byteArray, internalData, 0);
free(data); // was malloc'ed in addNativeArray
break; break;
} }
...@@ -169,78 +311,11 @@ void releaseCopiedVector(JNIEnv *env, CopiedVector cv) { ...@@ -169,78 +311,11 @@ void releaseCopiedVector(JNIEnv *env, CopiedVector cv) {
} }
default: default:
fatalError("copiedVector type"); fatalError("releaseNativeArray type");
}
}
}
void callExit(JNIEnv *env) {
// fprintf(traceFile, "callExit\n");
int i;
for (i = 0; i < copiedVectorsIndex; i++) {
releaseCopiedVector(env, copiedVectors[i]);
}
copiedVectorsIndex = 0;
}
void invalidateCopiedObject(JNIEnv *env, SEXP oldObj) {
int i;
for (i = 0; i < copiedVectorsIndex; i++) {
CopiedVector cv = copiedVectors[i];
if ((*env)->IsSameObject(env, cv.obj, oldObj)) {
#if TRACE_COPIES
fprintf(traceFile, "invalidateCopiedObject(%p): found\n", oldObj);
#endif
releaseCopiedVector(env, cv);
copiedVectors[i].obj = NULL;
}
}
#if TRACE_COPIES
fprintf(traceFile, "invalidateCopiedObject(%p): not found\n", oldObj);
#endif
}
void *findCopiedObject(JNIEnv *env, SEXP x) {
int i;
for (i = 0; i < copiedVectorsIndex; i++) {
CopiedVector cv = copiedVectors[i];
if ((*env)->IsSameObject(env, cv.obj, x)) {
void *data = cv.data;
#if TRACE_COPIES
fprintf(traceFile, "findCopiedObject(%p): found %p\n", x, data);
#endif
return data;
}
}
#if TRACE_COPIES
fprintf(traceFile, "findCopiedObject(%p): not found\n", x);
#endif
return NULL;
}
void addCopiedObject(JNIEnv *env, SEXP x, SEXPTYPE type, void *jArray, void *data) {
#if TRACE_COPIES
fprintf(traceFile, "addCopiedObject(%p, %p)\n", x, data);
#endif
if (copiedVectorsIndex >= copiedVectorsLength) {
int newLength = 2 * copiedVectorsLength;
CopiedVector *newCopiedVectors = calloc(newLength, sizeof(CopiedVector));
if (newCopiedVectors == NULL) {
fatalError("FFI copied vectors table expansion failure");
} }
memcpy(newCopiedVectors, copiedVectors, copiedVectorsLength * sizeof(CopiedVector)); // free up the slot
free(copiedVectors); cv.obj = NULL;
copiedVectors = newCopiedVectors;
copiedVectorsLength = newLength;
} }
copiedVectors[copiedVectorsIndex].obj = x;
copiedVectors[copiedVectorsIndex].data = data;
copiedVectors[copiedVectorsIndex].type = type;
copiedVectors[copiedVectorsIndex].jArray = jArray;
copiedVectorsIndex++;
#if TRACE_COPIES
fprintf(traceFile, "copiedVectorsIndex: %d\n", copiedVectorsIndex);
#endif
} }
static SEXP checkCachedGlobalRef(JNIEnv *env, SEXP obj, int useGlobal) { static SEXP checkCachedGlobalRef(JNIEnv *env, SEXP obj, int useGlobal) {
......
...@@ -62,11 +62,12 @@ void allocExit(); ...@@ -62,11 +62,12 @@ void allocExit();
jmp_buf *getErrorJmpBuf(); jmp_buf *getErrorJmpBuf();
// find an object for which we have cached the internal rep // Given the x denotes an R vector type, return a pointer to
void *findCopiedObject(JNIEnv *env, SEXP x); // the data as a C array
// add a new object to the internal rep cache void *getNativeArray(JNIEnv *env, SEXP x, SEXPTYPE type);
void addCopiedObject(JNIEnv *env, SEXP x, SEXPTYPE type, void *jArray, void *data); // Rare case where an operation changes the internal
void invalidateCopiedObject(JNIEnv *env, SEXP oldObj); // data and thus the old C array should be invalidated
void invalidateNativeArray(JNIEnv *env, SEXP oldObj);
void init_rmath(JNIEnv *env); void init_rmath(JNIEnv *env);
void init_variables(JNIEnv *env, jobjectArray initialValues); void init_variables(JNIEnv *env, jobjectArray initialValues);
...@@ -88,8 +89,8 @@ extern FILE *traceFile; ...@@ -88,8 +89,8 @@ extern FILE *traceFile;
// tracing/debugging support, set to 1 and recompile to enable // tracing/debugging support, set to 1 and recompile to enable
#define TRACE_UPCALLS 0 // trace upcalls #define TRACE_UPCALLS 0 // trace upcalls
#define TRACE_REF_CACHE 0 // trace JNI reference cache #define TRACE_REF_CACHE 0 // trace JNI reference cache
#define TRACE_COPIES 0 // trace copying of internal arrays #define TRACE_NATIVE_ARRAYS 0 // trace generation of internal arrays
#define TRACE_ENABLED TRACE_UPCALLS || TRACE_REF_CACHE || TRACE_COPIES #define TRACE_ENABLED TRACE_UPCALLS || TRACE_REF_CACHE || TRACE_NATIVE_ARRAYS
#define TARGp "%s(%p)\n" #define TARGp "%s(%p)\n"
#define TARGpp "%s(%p, %p)\n" #define TARGpp "%s(%p, %p)\n"
...@@ -111,4 +112,11 @@ extern FILE *traceFile; ...@@ -111,4 +112,11 @@ extern FILE *traceFile;
// convert a string into a char* // convert a string into a char*
const char *stringToChars(JNIEnv *jniEnv, jstring string); const char *stringToChars(JNIEnv *jniEnv, jstring string);
extern jmethodID INTEGER_MethodID;
extern jmethodID LOGICAL_MethodID;
extern jmethodID REAL_MethodID;
extern jmethodID RAW_MethodID;
extern int callDepth;
#endif /* RFFIUTILS_H */ #endif /* RFFIUTILS_H */
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment