Skip to content
Snippets Groups Projects
Commit 54c1b7ac authored by Lukas Stadler's avatar Lukas Stadler Committed by stepan
Browse files

prototype handling of character vectors in NFI .C

parent f27c342c
No related branches found
No related tags found
No related merge requests found
...@@ -22,25 +22,147 @@ ...@@ -22,25 +22,147 @@
*/ */
package com.oracle.truffle.r.ffi.impl.nfi; package com.oracle.truffle.r.ffi.impl.nfi;
import java.nio.charset.StandardCharsets;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.interop.CanResolve;
import com.oracle.truffle.api.interop.ForeignAccess; import com.oracle.truffle.api.interop.ForeignAccess;
import com.oracle.truffle.api.interop.InteropException; import com.oracle.truffle.api.interop.InteropException;
import com.oracle.truffle.api.interop.Message; import com.oracle.truffle.api.interop.Message;
import com.oracle.truffle.api.interop.MessageResolution;
import com.oracle.truffle.api.interop.Resolve;
import com.oracle.truffle.api.interop.TruffleObject; import com.oracle.truffle.api.interop.TruffleObject;
import com.oracle.truffle.api.interop.java.JavaInterop; import com.oracle.truffle.api.interop.java.JavaInterop;
import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.r.ffi.impl.interop.UnsafeAdapter;
import com.oracle.truffle.r.ffi.impl.nfi.TruffleNFI_CFactory.TruffleNFI_InvokeCNodeGen; import com.oracle.truffle.r.ffi.impl.nfi.TruffleNFI_CFactory.TruffleNFI_InvokeCNodeGen;
import com.oracle.truffle.r.runtime.RInternalError; import com.oracle.truffle.r.runtime.RInternalError;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractAtomicVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
import com.oracle.truffle.r.runtime.ffi.CRFFI; import com.oracle.truffle.r.runtime.ffi.CRFFI;
import com.oracle.truffle.r.runtime.ffi.NativeCallInfo; import com.oracle.truffle.r.runtime.ffi.NativeCallInfo;
import sun.misc.Unsafe;
public class TruffleNFI_C implements CRFFI { public class TruffleNFI_C implements CRFFI {
@MessageResolution(receiverType = StringWrapper.class)
public static class StringWrapperMR {
@Resolve(message = "IS_POINTER")
public abstract static class StringWrapperNativeIsPointerNode extends Node {
protected Object access(@SuppressWarnings("unused") StringWrapper receiver) {
return true;
}
}
@Resolve(message = "AS_POINTER")
public abstract static class StringWrapperNativeAsPointerNode extends Node {
protected Object access(StringWrapper receiver) {
return receiver.asPointer();
}
}
@CanResolve
public abstract static class StringWrapperCheck extends Node {
protected static boolean test(TruffleObject receiver) {
return receiver instanceof StringWrapper;
}
}
}
public static final class StringWrapper implements TruffleObject {
private final RAbstractStringVector vector;
private long address;
public StringWrapper(RAbstractStringVector vector) {
this.vector = vector;
}
@Override
public ForeignAccess getForeignAccess() {
return StringWrapperMRForeign.ACCESS;
}
public long asPointer() {
if (address == 0) {
address = allocate();
}
return address;
}
@TruffleBoundary
private long allocate() {
int length = vector.getLength();
int size = length * 8;
byte[][] bytes = new byte[length][];
for (int i = 0; i < length; i++) {
String element = vector.getDataAt(i);
bytes[i] = element.getBytes(StandardCharsets.US_ASCII);
size += bytes[i].length + 1;
}
long memory = UnsafeAdapter.UNSAFE.allocateMemory(size);
long ptr = memory + length * 8; // start of the actual character data
for (int i = 0; i < length; i++) {
UnsafeAdapter.UNSAFE.putLong(memory + i * 8, ptr);
UnsafeAdapter.UNSAFE.copyMemory(bytes[i], Unsafe.ARRAY_BYTE_BASE_OFFSET, null, ptr, bytes[i].length);
ptr += bytes[i].length;
UnsafeAdapter.UNSAFE.putByte(ptr++, (byte) 0);
}
assert ptr == memory + size : "should have filled everything";
return memory;
}
public RAbstractStringVector copyBack(RAbstractStringVector original) {
if (address == 0) {
return original;
} else {
RStringVector result = original.materialize();
String[] data = result.isTemporary() ? result.getDataWithoutCopying() : result.getDataCopy();
for (int i = 0; i < data.length; i++) {
long ptr = UnsafeAdapter.UNSAFE.getLong(address + i * 8);
int length = 0;
while (UnsafeAdapter.UNSAFE.getByte(ptr + length) != 0) {
length++;
}
byte[] bytes = new byte[length];
UnsafeAdapter.UNSAFE.copyMemory(null, ptr, bytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, length);
data[i] = new String(bytes, StandardCharsets.US_ASCII);
}
UnsafeAdapter.UNSAFE.freeMemory(address);
return RDataFactory.createStringVector(data, true);
}
}
}
abstract static class TruffleNFI_InvokeCNode extends InvokeCNode { abstract static class TruffleNFI_InvokeCNode extends InvokeCNode {
@Child private Node bindNode = Message.createInvoke(1).createNode(); @Child private Node bindNode = Message.createInvoke(1).createNode();
@Override
protected Object getNativeArgument(int index, ArgumentType type, RAbstractAtomicVector vector) {
if (type == ArgumentType.VECTOR_STRING) {
return new StringWrapper((RAbstractStringVector) vector);
} else {
return super.getNativeArgument(index, type, vector);
}
}
@Override
protected Object postProcessArgument(ArgumentType type, RAbstractAtomicVector vector, Object nativeArgument) {
if (type == ArgumentType.VECTOR_STRING) {
return ((StringWrapper) nativeArgument).copyBack((RAbstractStringVector) vector);
} else {
return super.postProcessArgument(type, vector, nativeArgument);
}
}
@Specialization(guards = "args.length == 0") @Specialization(guards = "args.length == 0")
protected void invokeCall0(NativeCallInfo nativeCallInfo, @SuppressWarnings("unused") Object[] args, @SuppressWarnings("unused") boolean hasStrings, protected void invokeCall0(NativeCallInfo nativeCallInfo, @SuppressWarnings("unused") Object[] args, @SuppressWarnings("unused") boolean hasStrings,
@Cached("createExecute(args.length)") Node executeNode) { @Cached("createExecute(args.length)") Node executeNode) {
...@@ -84,8 +206,8 @@ public class TruffleNFI_C implements CRFFI { ...@@ -84,8 +206,8 @@ public class TruffleNFI_C implements CRFFI {
sb.append("[sint32]"); sb.append("[sint32]");
} else if (arg instanceof double[]) { } else if (arg instanceof double[]) {
sb.append("[double]"); sb.append("[double]");
} else if (arg instanceof byte[][]) { } else if (arg instanceof StringWrapper) {
sb.append("[pointer]"); sb.append("pointer");
} else { } else {
throw RInternalError.unimplemented(".C type: " + arg.getClass().getSimpleName()); throw RInternalError.unimplemented(".C type: " + arg.getClass().getSimpleName());
} }
......
...@@ -40,7 +40,6 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; ...@@ -40,7 +40,6 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.nodes.RBaseNode; import com.oracle.truffle.r.runtime.nodes.RBaseNode;
/** /**
...@@ -50,7 +49,7 @@ public interface CRFFI { ...@@ -50,7 +49,7 @@ public interface CRFFI {
public static abstract class InvokeCNode extends RBaseNode { public static abstract class InvokeCNode extends RBaseNode {
enum ArgumentType { public enum ArgumentType {
VECTOR_DOUBLE, VECTOR_DOUBLE,
VECTOR_INT, VECTOR_INT,
VECTOR_LOGICAL, VECTOR_LOGICAL,
...@@ -71,7 +70,7 @@ public interface CRFFI { ...@@ -71,7 +70,7 @@ public interface CRFFI {
protected abstract void execute(NativeCallInfo nativeCallInfo, Object[] args, boolean hasStrings); protected abstract void execute(NativeCallInfo nativeCallInfo, Object[] args, boolean hasStrings);
@TruffleBoundary @TruffleBoundary
private Object getNativeArgument(int index, ArgumentType type, RAbstractAtomicVector vector) { protected Object getNativeArgument(int index, ArgumentType type, RAbstractAtomicVector vector) {
CompilerAsserts.neverPartOfCompilation(); CompilerAsserts.neverPartOfCompilation();
switch (type) { switch (type) {
case VECTOR_DOUBLE: { case VECTOR_DOUBLE: {
...@@ -121,7 +120,7 @@ public interface CRFFI { ...@@ -121,7 +120,7 @@ public interface CRFFI {
} }
} }
protected Object[] getNativeArguments(Object[] array, ArgumentType[] argTypes) { private Object[] getNativeArguments(Object[] array, ArgumentType[] argTypes) {
Object[] nativeArgs = new Object[array.length]; Object[] nativeArgs = new Object[array.length];
for (int i = 0; i < array.length; i++) { for (int i = 0; i < array.length; i++) {
nativeArgs[i] = getNativeArgument(i, argTypes[i], (RAbstractAtomicVector) array[i]); nativeArgs[i] = getNativeArgument(i, argTypes[i], (RAbstractAtomicVector) array[i]);
...@@ -130,7 +129,7 @@ public interface CRFFI { ...@@ -130,7 +129,7 @@ public interface CRFFI {
} }
@TruffleBoundary @TruffleBoundary
private static Object postProcessArgument(ArgumentType type, RAbstractAtomicVector vector, Object nativeArgument) { protected Object postProcessArgument(ArgumentType type, RAbstractAtomicVector vector, Object nativeArgument) {
switch (type) { switch (type) {
case VECTOR_STRING: case VECTOR_STRING:
return ((RAbstractStringVector) vector).materialize().copyResetData(decodeStrings((byte[][]) nativeArgument)); return ((RAbstractStringVector) vector).materialize().copyResetData(decodeStrings((byte[][]) nativeArgument));
...@@ -151,7 +150,7 @@ public interface CRFFI { ...@@ -151,7 +150,7 @@ public interface CRFFI {
} }
} }
protected Object[] postProcessArguments(Object[] array, ArgumentType[] argTypes, Object[] nativeArgs) { private Object[] postProcessArguments(Object[] array, ArgumentType[] argTypes, Object[] nativeArgs) {
Object[] results = new Object[array.length]; Object[] results = new Object[array.length];
for (int i = 0; i < array.length; i++) { for (int i = 0; i < array.length; i++) {
results[i] = postProcessArgument(argTypes[i], (RAbstractAtomicVector) array[i], nativeArgs[i]); results[i] = postProcessArgument(argTypes[i], (RAbstractAtomicVector) array[i], nativeArgs[i]);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment