Skip to content
Snippets Groups Projects
Commit ad659f25 authored by Lukas Stadler's avatar Lukas Stadler Committed by stepan
Browse files

allow RIntVector and RLogicalVector contents to live in native space

parent e4fb8c37
No related branches found
No related tags found
No related merge requests found
......@@ -22,17 +22,33 @@
*/
package com.oracle.truffle.r.ffi.impl.nfi;
import static com.oracle.truffle.r.ffi.impl.common.RFFIUtils.guaranteeInstanceOf;
import java.nio.charset.StandardCharsets;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.interop.CanResolve;
import com.oracle.truffle.api.interop.ForeignAccess;
import com.oracle.truffle.api.interop.InteropException;
import com.oracle.truffle.api.interop.Message;
import com.oracle.truffle.api.interop.MessageResolution;
import com.oracle.truffle.api.interop.Resolve;
import com.oracle.truffle.api.interop.TruffleObject;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.r.ffi.impl.common.JavaUpCallsRFFIImpl;
import com.oracle.truffle.r.ffi.impl.common.RFFIUtils;
import com.oracle.truffle.r.ffi.impl.interop.UnsafeAdapter;
import com.oracle.truffle.r.ffi.impl.nfi.TruffleNFI_C.StringWrapper;
import com.oracle.truffle.r.ffi.impl.upcalls.FFIUnwrapNode;
import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RInternalError;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RIntVector;
import com.oracle.truffle.r.runtime.data.RLogicalVector;
import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.RVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
import com.oracle.truffle.r.runtime.ffi.CharSXPWrapper;
import com.oracle.truffle.r.runtime.ffi.DLL;
import com.oracle.truffle.r.runtime.ffi.DLL.CEntry;
......@@ -41,6 +57,8 @@ import com.oracle.truffle.r.runtime.ffi.DLL.DotSymbol;
import com.oracle.truffle.r.runtime.ffi.DLL.SymbolHandle;
import com.oracle.truffle.r.runtime.gnur.SEXPTYPE;
import sun.misc.Unsafe;
public class TruffleNFI_UpCallsRFFIImpl extends JavaUpCallsRFFIImpl {
private static final String SETSYMBOL_SIGNATURE = "(pointer, sint32, pointer, sint32): pointer";
......@@ -71,35 +89,69 @@ public class TruffleNFI_UpCallsRFFIImpl extends JavaUpCallsRFFIImpl {
return CharSXPWrapper.create(TruffleNFI_Utils.getString(address, len));
}
@MessageResolution(receiverType = VectorWrapper.class)
public static class VectorWrapperMR {
@Resolve(message = "IS_POINTER")
public abstract static class IntVectorWrapperNativeIsPointerNode extends Node {
protected Object access(@SuppressWarnings("unused") VectorWrapper receiver) {
return true;
}
}
@Resolve(message = "AS_POINTER")
public abstract static class IntVectorWrapperNativeAsPointerNode extends Node {
protected long access(VectorWrapper receiver) {
RVector<?> v = receiver.vector;
if (v instanceof RIntVector) {
return ((RIntVector) v).allocateNativeContents();
} else if (v instanceof RLogicalVector) {
return ((RLogicalVector) v).allocateNativeContents();
} else {
throw RInternalError.shouldNotReachHere();
}
}
}
@CanResolve
public abstract static class VectorWrapperCheck extends Node {
protected static boolean test(TruffleObject receiver) {
return receiver instanceof VectorWrapper;
}
}
}
public static final class VectorWrapper implements TruffleObject {
private final RVector<?> vector;
public VectorWrapper(RVector<?> vector) {
this.vector = vector;
}
@Override
public ForeignAccess getForeignAccess() {
return VectorWrapperMRForeign.ACCESS;
}
}
@Override
public Object INTEGER(Object x) {
long arrayAddress = TruffleNFI_NativeArray.findArray(x);
if (arrayAddress == 0) {
Object array = super.INTEGER(x);
arrayAddress = TruffleNFI_NativeArray.recordArray(x, array, SEXPTYPE.INTSXP);
} else {
TruffleNFI_Call.returnArrayExisting(SEXPTYPE.INTSXP, arrayAddress);
}
return x;
// also handles LOGICAL
assert x instanceof RIntVector || x instanceof RLogicalVector;
return new VectorWrapper(guaranteeInstanceOf(x, RVector.class));
}
@Override
public Object LOGICAL(Object x) {
long arrayAddress = TruffleNFI_NativeArray.findArray(x);
if (arrayAddress == 0) {
Object array = super.LOGICAL(x);
arrayAddress = TruffleNFI_NativeArray.recordArray(x, array, SEXPTYPE.LGLSXP);
} else {
TruffleNFI_Call.returnArrayExisting(SEXPTYPE.LGLSXP, arrayAddress);
}
return x;
return new VectorWrapper(guaranteeInstanceOf(x, RLogicalVector.class));
}
@Override
public Object REAL(Object x) {
long arrayAddress = TruffleNFI_NativeArray.findArray(x);
if (arrayAddress == 0) {
System.out.println("getting REAL contents");
Object array = super.REAL(x);
arrayAddress = TruffleNFI_NativeArray.recordArray(x, array, SEXPTYPE.REALSXP);
} else {
......@@ -113,6 +165,7 @@ public class TruffleNFI_UpCallsRFFIImpl extends JavaUpCallsRFFIImpl {
public Object RAW(Object x) {
long arrayAddress = TruffleNFI_NativeArray.findArray(x);
if (arrayAddress == 0) {
System.out.println("getting RAW contents");
Object array = super.RAW(x);
arrayAddress = TruffleNFI_NativeArray.recordArray(x, array, SEXPTYPE.RAWSXP);
} else {
......@@ -125,6 +178,7 @@ public class TruffleNFI_UpCallsRFFIImpl extends JavaUpCallsRFFIImpl {
public Object R_CHAR(Object x) {
long arrayAddress = TruffleNFI_NativeArray.findArray(x);
if (arrayAddress == 0) {
System.out.println("getting R_CHAR contents");
CharSXPWrapper charSXP = (CharSXPWrapper) x;
Object array = charSXP.getContents().getBytes();
arrayAddress = TruffleNFI_NativeArray.recordArray(x, array, SEXPTYPE.CHARSXP);
......
......@@ -25,6 +25,7 @@ package com.oracle.truffle.r.ffi.impl.upcalls;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.r.ffi.impl.nfi.TruffleNFI_UpCallsRFFIImpl.VectorWrapper;
import com.oracle.truffle.r.runtime.data.RComplex;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RDouble;
......@@ -93,6 +94,11 @@ public abstract class FFIWrapNode extends Node {
return value;
}
@Specialization
protected static Object wrap(VectorWrapper value) {
return value;
}
@Fallback
protected static Object wrap(Object value) {
System.out.println("invalid wrapping: " + value.getClass().getSimpleName());
......
......@@ -98,13 +98,11 @@ void return_FREE(void *address) {
}
int *INTEGER(SEXP x) {
((call_INTEGER) callbacks[INTEGER_x])(x);
return return_int;
return ((call_INTEGER) callbacks[INTEGER_x])(x);
}
int *LOGICAL(SEXP x){
((call_LOGICAL) callbacks[LOGICAL_x])(x);
return return_int;
return ((call_LOGICAL) callbacks[LOGICAL_x])(x);
}
double *REAL(SEXP x){
......
......@@ -37,6 +37,7 @@ import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.nodes.RootNode;
import com.oracle.truffle.api.source.SourceSection;
import com.oracle.truffle.r.runtime.RInternalError;
import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.context.RContext;
import com.oracle.truffle.r.runtime.nodes.RSyntaxNode;
......@@ -79,11 +80,19 @@ public final class NativeDataAccess {
private static final class NativeMirror {
private final long id;
private long dataAddress;
private long length;
NativeMirror() {
this.id = counter.incrementAndGet();
}
void allocateNative(Object source, long sourceLength, int len, int elementBase, int elementSize) {
assert dataAddress == 0;
dataAddress = UnsafeAdapter.UNSAFE.allocateMemory(sourceLength * elementSize);
UnsafeAdapter.UNSAFE.copyMemory(source, elementBase, null, dataAddress, sourceLength * elementSize);
this.length = len;
}
@Override
protected void finalize() throws Throwable {
super.finalize();
......@@ -148,4 +157,72 @@ public final class NativeDataAccess {
}
return result;
}
static long getDataLength(RVector<?> vector) {
return ((NativeMirror) vector.getNativeMirror()).length;
}
static int getIntData(RVector<?> vector, int index) {
long address = ((NativeMirror) vector.getNativeMirror()).dataAddress;
assert address != 0;
return UnsafeAdapter.UNSAFE.getInt(address + index * Unsafe.ARRAY_INT_INDEX_SCALE);
}
static void setIntData(RVector<?> vector, int index, int value) {
long address = ((NativeMirror) vector.getNativeMirror()).dataAddress;
assert address != 0;
UnsafeAdapter.UNSAFE.putInt(address + index * Unsafe.ARRAY_INT_INDEX_SCALE, value);
}
static double getDoubleData(RVector<?> vector, int index) {
long address = ((NativeMirror) vector.getNativeMirror()).dataAddress;
assert address != 0;
return UnsafeAdapter.UNSAFE.getDouble(address + index * Unsafe.ARRAY_INT_INDEX_SCALE);
}
static void setDoubleData(RVector<?> vector, int index, double value) {
long address = ((NativeMirror) vector.getNativeMirror()).dataAddress;
assert address != 0;
UnsafeAdapter.UNSAFE.putDouble(address + index * Unsafe.ARRAY_INT_INDEX_SCALE, value);
}
static long allocateNativeContents(RLogicalVector vector, byte[] data, int length) {
NativeMirror mirror = (NativeMirror) vector.getNativeMirror();
assert mirror != null;
assert mirror.dataAddress == 0 ^ data == null;
if (mirror.dataAddress == 0) {
// System.out.println(String.format("allocating native for logical vector %16x",
// mirror.id));
int[] intArray = new int[data.length];
for (int i = 0; i < data.length; i++) {
intArray[i] = RRuntime.logical2int(data[i]);
}
((NativeMirror) vector.getNativeMirror()).allocateNative(intArray, data.length, length, Unsafe.ARRAY_INT_BASE_OFFSET, Unsafe.ARRAY_INT_INDEX_SCALE);
}
return mirror.dataAddress;
}
static long allocateNativeContents(RIntVector vector, int[] data, int length) {
NativeMirror mirror = (NativeMirror) vector.getNativeMirror();
assert mirror != null;
assert mirror.dataAddress == 0 ^ data == null;
if (mirror.dataAddress == 0) {
// System.out.println(String.format("allocating native for int vector %16x",
// mirror.id));
((NativeMirror) vector.getNativeMirror()).allocateNative(data, data.length, length, Unsafe.ARRAY_INT_BASE_OFFSET, Unsafe.ARRAY_INT_INDEX_SCALE);
}
return mirror.dataAddress;
}
static long allocateNativeContents(RDoubleVector vector, double[] data, int length) {
NativeMirror mirror = (NativeMirror) vector.getNativeMirror();
assert mirror != null;
assert mirror.dataAddress == 0 ^ data == null;
if (mirror.dataAddress == 0) {
// System.out.println(String.format("allocating native for double vector %16x",
// mirror.id));
((NativeMirror) vector.getNativeMirror()).allocateNative(data, data.length, length, Unsafe.ARRAY_DOUBLE_BASE_OFFSET, Unsafe.ARRAY_DOUBLE_INDEX_SCALE);
}
return mirror.dataAddress;
}
}
......@@ -35,7 +35,7 @@ import com.oracle.truffle.r.runtime.ops.na.NACheck;
public final class RIntVector extends RVector<int[]> implements RAbstractIntVector {
private final int[] data;
private int[] data;
RIntVector(int[] data, boolean complete) {
super(complete);
......@@ -68,24 +68,29 @@ public final class RIntVector extends RVector<int[]> implements RAbstractIntVect
@Override
public int[] getInternalStore() {
assert data != null;
return data;
}
@Override
public int getDataAt(int index) {
return data[index];
return data == null ? NativeDataAccess.getIntData(this, index) : data[index];
}
@Override
public int getDataAt(Object store, int index) {
assert data == store;
return ((int[]) store)[index];
return store == null ? NativeDataAccess.getIntData(this, index) : ((int[]) store)[index];
}
@Override
public void setDataAt(Object store, int index, int value) {
assert data == store;
((int[]) store)[index] = value;
if (store == null) {
NativeDataAccess.setIntData(this, index, value);
} else {
((int[]) store)[index] = value;
}
}
@Override
......@@ -108,7 +113,7 @@ public final class RIntVector extends RVector<int[]> implements RAbstractIntVect
@Override
public int getLength() {
return data.length;
return data == null ? (int) NativeDataAccess.getDataLength(this) : data.length;
}
@Override
......@@ -119,8 +124,8 @@ public final class RIntVector extends RVector<int[]> implements RAbstractIntVect
@Override
public boolean verify() {
if (isComplete()) {
for (int x : data) {
if (x == RRuntime.INT_NA) {
for (int i = 0; i < getLength(); i++) {
if (getDataAt(i) == RRuntime.INT_NA) {
return false;
}
}
......@@ -130,6 +135,7 @@ public final class RIntVector extends RVector<int[]> implements RAbstractIntVect
@Override
public int[] getDataCopy() {
assert data != null;
return Arrays.copyOf(data, data.length);
}
......@@ -139,6 +145,7 @@ public final class RIntVector extends RVector<int[]> implements RAbstractIntVect
*/
@Override
public int[] getDataWithoutCopying() {
assert data != null;
return data;
}
......@@ -147,13 +154,17 @@ public final class RIntVector extends RVector<int[]> implements RAbstractIntVect
return RDataFactory.createIntVector(data, isComplete(), newDimensions);
}
public RIntVector updateDataAt(int i, int right, NACheck valueNACheck) {
public RIntVector updateDataAt(int index, int value, NACheck valueNACheck) {
assert !this.isShared();
data[i] = right;
if (valueNACheck.check(right)) {
if (data == null) {
NativeDataAccess.setIntData(this, index, value);
} else {
data[index] = value;
}
if (valueNACheck.check(value)) {
setComplete(false);
}
assert !isComplete() || !RRuntime.isNA(right);
assert !isComplete() || !RRuntime.isNA(value);
return this;
}
......@@ -178,12 +189,14 @@ public final class RIntVector extends RVector<int[]> implements RAbstractIntVect
}
private int[] copyResizedData(int size, boolean fillNA) {
assert data != null;
int[] newData = Arrays.copyOf(data, size);
return resizeData(newData, this.data, this.getLength(), fillNA);
}
@Override
protected RIntVector internalCopyResized(int size, boolean fillNA, int[] dimensions) {
assert data != null;
boolean isComplete = isComplete() && ((data.length >= size) || !fillNA);
return RDataFactory.createIntVector(copyResizedData(size, fillNA), isComplete, dimensions);
}
......@@ -201,6 +214,11 @@ public final class RIntVector extends RVector<int[]> implements RAbstractIntVect
@Override
public void transferElementSameType(int toIndex, RAbstractVector fromVector, int fromIndex) {
RAbstractIntVector other = (RAbstractIntVector) fromVector;
if (data == null) {
NativeDataAccess.setIntData(this, toIndex, other.getDataAt(fromIndex));
} else {
data[toIndex] = other.getDataAt(fromIndex);
}
data[toIndex] = other.getDataAt(fromIndex);
}
......@@ -210,7 +228,19 @@ public final class RIntVector extends RVector<int[]> implements RAbstractIntVect
}
@Override
public void setElement(int i, Object value) {
data[i] = (int) value;
public void setElement(int index, Object value) {
if (data == null) {
NativeDataAccess.setIntData(this, index, (int) value);
} else {
data[index] = (int) value;
}
}
public long allocateNativeContents() {
try {
return NativeDataAccess.allocateNativeContents(this, data, getLength());
} finally {
data = null;
}
}
}
......@@ -35,7 +35,7 @@ import com.oracle.truffle.r.runtime.ops.na.NACheck;
public final class RLogicalVector extends RVector<byte[]> implements RAbstractLogicalVector {
private final byte[] data;
private byte[] data;
RLogicalVector(byte[] data, boolean complete) {
super(complete);
......@@ -70,23 +70,29 @@ public final class RLogicalVector extends RVector<byte[]> implements RAbstractLo
@Override
public byte[] getInternalStore() {
assert data != null;
return data;
}
@Override
public void setDataAt(Object store, int index, byte value) {
assert data == store;
((byte[]) store)[index] = value;
if (store == null) {
NativeDataAccess.setIntData(this, index, RRuntime.logical2int(value));
} else {
((byte[]) store)[index] = value;
}
}
@Override
public byte getDataAt(Object store, int index) {
assert data == store;
return ((byte[]) store)[index];
return data == null ? (byte) NativeDataAccess.getIntData(this, index) : data[index];
}
@Override
protected RLogicalVector internalCopy() {
assert data != null;
return new RLogicalVector(Arrays.copyOf(data, data.length), isComplete());
}
......@@ -105,7 +111,7 @@ public final class RLogicalVector extends RVector<byte[]> implements RAbstractLo
@Override
public int getLength() {
return data.length;
return data == null ? (int) NativeDataAccess.getDataLength(this) : data.length;
}
@Override
......@@ -116,8 +122,8 @@ public final class RLogicalVector extends RVector<byte[]> implements RAbstractLo
@Override
public boolean verify() {
if (isComplete()) {
for (byte b : data) {
if (b == RRuntime.LOGICAL_NA) {
for (int i = 0; i < getLength(); i++) {
if (getDataAt(i) == RRuntime.LOGICAL_NA) {
return false;
}
}
......@@ -126,17 +132,21 @@ public final class RLogicalVector extends RVector<byte[]> implements RAbstractLo
}
@Override
public byte getDataAt(int i) {
return data[i];
public byte getDataAt(int index) {
return data == null ? (byte) NativeDataAccess.getIntData(this, index) : data[index];
}
private RLogicalVector updateDataAt(int index, byte right, NACheck valueNACheck) {
private RLogicalVector updateDataAt(int index, byte value, NACheck valueNACheck) {
assert !this.isShared();
data[index] = right;
if (valueNACheck.check(right)) {
if (data == null) {
NativeDataAccess.setIntData(this, index, RRuntime.logical2int(value));
} else {
data[index] = value;
}
if (valueNACheck.check(value)) {
setComplete(false);
}
assert !isComplete() || !RRuntime.isNA(right);
assert !isComplete() || !RRuntime.isNA(value);
return this;
}
......@@ -147,6 +157,7 @@ public final class RLogicalVector extends RVector<byte[]> implements RAbstractLo
}
private byte[] copyResizedData(int size, boolean fillNA) {
assert data != null;
byte[] newData = Arrays.copyOf(data, size);
if (size > this.getLength()) {
if (fillNA) {
......@@ -164,7 +175,7 @@ public final class RLogicalVector extends RVector<byte[]> implements RAbstractLo
@Override
protected RLogicalVector internalCopyResized(int size, boolean fillNA, int[] dimensions) {
boolean isComplete = isComplete() && ((data.length >= size) || !fillNA);
boolean isComplete = isComplete() && ((getLength() >= size) || !fillNA);
return RDataFactory.createLogicalVector(copyResizedData(size, fillNA), isComplete, dimensions);
}
......@@ -176,11 +187,16 @@ public final class RLogicalVector extends RVector<byte[]> implements RAbstractLo
@Override
public void transferElementSameType(int toIndex, RAbstractVector fromVector, int fromIndex) {
RAbstractLogicalVector other = (RAbstractLogicalVector) fromVector;
data[toIndex] = other.getDataAt(fromIndex);
if (data == null) {
NativeDataAccess.setIntData(this, toIndex, RRuntime.logical2int(other.getDataAt(fromIndex)));
} else {
data[toIndex] = other.getDataAt(fromIndex);
}
}
@Override
public byte[] getDataCopy() {
assert data != null;
return Arrays.copyOf(data, data.length);
}
......@@ -190,11 +206,13 @@ public final class RLogicalVector extends RVector<byte[]> implements RAbstractLo
*/
@Override
public byte[] getDataWithoutCopying() {
assert data != null;
return data;
}
@Override
public RLogicalVector copyWithNewDimensions(int[] newDimensions) {
assert data != null;
return RDataFactory.createLogicalVector(data, isComplete(), newDimensions);
}
......@@ -207,4 +225,12 @@ public final class RLogicalVector extends RVector<byte[]> implements RAbstractLo
public Object getDataAtAsObject(int index) {
return getDataAt(index);
}
public long allocateNativeContents() {
try {
return NativeDataAccess.allocateNativeContents(this, data, getLength());
} finally {
data = null;
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment