Skip to content
Snippets Groups Projects
Commit ebc0383d authored by Lukas Stadler's avatar Lukas Stadler
Browse files

convert raw and connection functions to VectorAccess

parent 5b5b5409
No related branches found
No related tags found
No related merge requests found
......@@ -59,6 +59,7 @@ import java.nio.ShortBuffer;
import java.nio.channels.ByteChannel;
import java.nio.charset.Charset;
import java.nio.charset.IllegalCharsetNameException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
......@@ -90,7 +91,6 @@ import com.oracle.truffle.r.runtime.conn.SocketConnections.RSocketConnection;
import com.oracle.truffle.r.runtime.conn.TextConnections.TextRConnection;
import com.oracle.truffle.r.runtime.conn.URLConnections.URLRConnection;
import com.oracle.truffle.r.runtime.context.RContext;
import com.oracle.truffle.r.runtime.data.RComplex;
import com.oracle.truffle.r.runtime.data.RComplexVector;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RDoubleVector;
......@@ -106,13 +106,12 @@ import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.RTypes;
import com.oracle.truffle.r.runtime.data.RVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractAtomicVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
import com.oracle.truffle.r.runtime.env.REnvironment;
import com.oracle.truffle.r.runtime.nodes.RBaseNode;
......@@ -1066,90 +1065,83 @@ public abstract class ConnectionFunctions {
return WriteDataNodeGen.create();
}
@TruffleBoundary
private static ByteBuffer allocate(int capacity, boolean swap) {
ByteBuffer buffer = ByteBuffer.allocate(capacity);
checkOrder(buffer, swap);
return buffer;
}
@Specialization
protected ByteBuffer writeInteger(RAbstractIntVector object, @SuppressWarnings("unused") int size, boolean swap, @SuppressWarnings("unused") boolean useBytes) {
int length = object.getLength();
ByteBuffer buffer = allocate(4 * length, swap);
for (int i = 0; i < length; i++) {
int value = object.getDataAt(i);
buffer.putInt(value);
}
return buffer;
}
@Specialization
protected ByteBuffer writeDouble(RAbstractDoubleVector object, @SuppressWarnings("unused") int size, boolean swap, @SuppressWarnings("unused") boolean useBytes) {
int length = object.getLength();
ByteBuffer buffer = allocate(8 * length, swap);
for (int i = 0; i < length; i++) {
double value = object.getDataAt(i);
buffer.putDouble(value);
}
return buffer;
}
@Specialization
protected ByteBuffer writeComplex(RAbstractComplexVector object, @SuppressWarnings("unused") int size, boolean swap, @SuppressWarnings("unused") boolean useBytes) {
int length = object.getLength();
ByteBuffer buffer = allocate(16 * length, swap);
for (int i = 0; i < length; i++) {
RComplex complex = object.getDataAt(i);
double re = complex.getRealPart();
double im = complex.getImaginaryPart();
buffer.putDouble(re);
buffer.putDouble(im);
}
return buffer;
}
@Specialization
@TruffleBoundary
protected ByteBuffer writeString(RAbstractStringVector object, @SuppressWarnings("unused") int size, boolean swap, @SuppressWarnings("unused") boolean useBytes) {
int length = object.getLength();
byte[][] data = new byte[length][];
int totalLength = 0;
for (int i = 0; i < length; i++) {
String s = object.getDataAt(i);
// There is no special encoding for NA_character_
data[i] = s.getBytes();
totalLength = totalLength + data[i].length + 1; // zero pad
}
ByteBuffer buffer = allocate(totalLength, swap);
for (int i = 0; i < length; i++) {
buffer.put(data[i]);
buffer.put((byte) 0);
}
return buffer;
}
private static byte[] encodeString(String s) {
return s.getBytes(StandardCharsets.UTF_8);
}
@Specialization(guards = "objectAccess.supports(object)")
protected ByteBuffer write(RAbstractVector object, @SuppressWarnings("unused") int size, boolean swap, @SuppressWarnings("unused") boolean useBytes,
@Cached("object.access()") VectorAccess objectAccess) {
try (SequentialIterator iter = objectAccess.access(object)) {
int length = objectAccess.getLength(iter);
ByteBuffer buffer;
switch (objectAccess.getType()) {
case Integer:
buffer = allocate(4 * length, swap);
while (objectAccess.next(iter)) {
buffer.putInt(objectAccess.getInt(iter));
}
return buffer;
case Double:
buffer = allocate(8 * length, swap);
while (objectAccess.next(iter)) {
buffer.putDouble(objectAccess.getDouble(iter));
}
return buffer;
case Complex:
buffer = allocate(16 * length, swap);
while (objectAccess.next(iter)) {
buffer.putDouble(objectAccess.getComplexR(iter));
buffer.putDouble(objectAccess.getComplexI(iter));
}
return buffer;
case Character:
byte[][] data = new byte[length][];
int totalLength = 0;
while (objectAccess.next(iter)) {
// There is no special encoding for NA_character_
data[iter.getIndex()] = encodeString(objectAccess.getString(iter));
// zero pad
totalLength = totalLength + data[iter.getIndex()].length + 1;
}
@Specialization
protected ByteBuffer writeLogical(RAbstractLogicalVector object, @SuppressWarnings("unused") int size, boolean swap, @SuppressWarnings("unused") boolean useBytes) {
// encoded as ints, with FALSE=0, TRUE=1, NA=Integer_NA_
int length = object.getLength();
ByteBuffer buffer = allocate(4 * length, swap);
for (int i = 0; i < length; i++) {
byte value = object.getDataAt(i);
int encoded = RRuntime.isNA(value) ? RRuntime.INT_NA : value == RRuntime.LOGICAL_FALSE ? 0 : 1;
buffer.putInt(encoded);
buffer = allocate(totalLength, swap);
for (int i = 0; i < length; i++) {
buffer.put(data[i]);
buffer.put((byte) 0);
}
return buffer;
case Logical:
buffer = allocate(4 * length, swap);
while (objectAccess.next(iter)) {
buffer.putInt(objectAccess.getInt(iter)); // converted to int
}
return buffer;
case Raw:
buffer = allocate(length, swap);
while (objectAccess.next(iter)) {
buffer.put(objectAccess.getRaw(iter));
}
return buffer;
default:
throw RInternalError.shouldNotReachHere();
}
}
return buffer;
}
@Specialization
protected ByteBuffer writeRaw(RAbstractRawVector object, @SuppressWarnings("unused") int size, boolean swap, @SuppressWarnings("unused") boolean useBytes) {
int length = object.getLength();
ByteBuffer buffer = allocate(length, swap);
for (int i = 0; i < length; i++) {
buffer.put(object.getRawDataAt(i));
}
return buffer;
@Specialization(replaces = "write")
@TruffleBoundary
protected ByteBuffer writeGeneric(RAbstractVector object, int size, boolean swap, boolean useBytes) {
return write(object, size, swap, useBytes, object.slowPathAccess());
}
}
......
......@@ -40,9 +40,11 @@ import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RRawVector;
import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
/**
* Conversion and manipulation of objects of type "raw".
......@@ -57,17 +59,26 @@ public class RawFunctions {
casts.arg("x").defaultError(RError.Message.ARG_MUST_BE_CHARACTER_VECTOR_LENGTH_ONE).mustBe(stringValue()).asStringVector().mustBe(notEmpty());
}
@Specialization
protected RRawVector charToRaw(RAbstractStringVector x) {
if (x.getLength() > 1) {
warning(RError.Message.ARG_SHOULD_BE_CHARACTER_VECTOR_LENGTH_ONE);
}
String s = x.getDataAt(0);
byte[] data = new byte[s.length()];
for (int i = 0; i < data.length; i++) {
data[i] = (byte) s.charAt(i);
@Specialization(guards = "xAccess.supports(x)")
protected RRawVector charToRaw(RAbstractStringVector x,
@Cached("x.access()") VectorAccess xAccess) {
try (RandomIterator iter = xAccess.randomAccess(x)) {
if (xAccess.getLength(iter) != 1) {
warning(RError.Message.ARG_SHOULD_BE_CHARACTER_VECTOR_LENGTH_ONE);
}
String s = xAccess.getString(iter, 0);
byte[] data = new byte[s.length()];
for (int i = 0; i < data.length; i++) {
data[i] = (byte) s.charAt(i);
}
return RDataFactory.createRawVector(data);
}
return RDataFactory.createRawVector(data);
}
@Specialization(replaces = "charToRaw")
@TruffleBoundary
protected RRawVector charToRawGeneric(RAbstractStringVector x) {
return charToRaw(x, x.slowPathAccess());
}
}
......@@ -80,28 +91,45 @@ public class RawFunctions {
casts.arg("multiple").defaultError(RError.Message.INVALID_LOGICAL).asLogicalVector().findFirst().mustNotBeNA().map(toBoolean());
}
@Specialization
@TruffleBoundary
protected RStringVector rawToChar(RAbstractRawVector x, boolean multiple) {
RStringVector result;
if (multiple) {
String[] data = new String[x.getLength()];
for (int i = 0; i < data.length; i++) {
data[i] = new String(new byte[]{x.getRawDataAt(i)});
}
result = RDataFactory.createStringVector(data, RDataFactory.COMPLETE_VECTOR);
} else {
int j = 0;
byte[] data = new byte[x.getLength()];
for (int i = 0; i < data.length; i++) {
byte b = x.getRawDataAt(i);
if (b != 0) {
data[j++] = b;
private static String createString(int j, byte[] data) {
return new String(data, 0, j);
}
@TruffleBoundary
private static String createString(byte value) {
return new String(new byte[]{value});
}
@Specialization(guards = "xAccess.supports(x)")
protected Object rawToChar(RAbstractRawVector x, boolean multiple,
@Cached("x.access()") VectorAccess xAccess) {
try (SequentialIterator iter = xAccess.access(x)) {
if (multiple) {
String[] data = new String[xAccess.getLength(iter)];
while (xAccess.next(iter)) {
byte value = xAccess.getRaw(iter);
data[iter.getIndex()] = createString(value);
}
return RDataFactory.createStringVector(data, RDataFactory.COMPLETE_VECTOR);
} else {
int j = 0;
byte[] data = new byte[xAccess.getLength(iter)];
while (xAccess.next(iter)) {
byte b = xAccess.getRaw(iter);
if (b != 0) {
data[j++] = b;
}
}
return createString(j, data);
}
result = RDataFactory.createStringVectorFromScalar(new String(data, 0, j));
}
return result;
}
@Specialization(replaces = "rawToChar")
@TruffleBoundary
protected Object rawToCharGeneric(RAbstractRawVector x, boolean multiple) {
return rawToChar(x, multiple, x.slowPathAccess());
}
}
......@@ -114,23 +142,30 @@ public class RawFunctions {
casts.arg("n").defaultError(RError.Message.MUST_BE_SMALL_INT, "shift").asIntegerVector().findFirst().mustNotBeNA().mustBe(gte(-8).and(lte(8)));
}
@Specialization
@Specialization(guards = "xAccess.supports(x)")
protected RRawVector rawShift(RAbstractRawVector x, int n,
@Cached("createBinaryProfile()") ConditionProfile negativeShiftProfile) {
byte[] data = new byte[x.getLength()];
if (negativeShiftProfile.profile(n < 0)) {
for (int i = 0; i < data.length; i++) {
data[i] = (byte) ((x.getRawDataAt(i) & 0xff) >> (-n));
}
} else {
for (int i = 0; i < data.length; i++) {
data[i] = (byte) (x.getRawDataAt(i) << n);
@Cached("createBinaryProfile()") ConditionProfile negativeShiftProfile,
@Cached("x.access()") VectorAccess xAccess) {
try (SequentialIterator iter = xAccess.access(x)) {
byte[] data = new byte[xAccess.getLength(iter)];
if (negativeShiftProfile.profile(n < 0)) {
while (xAccess.next(iter)) {
data[iter.getIndex()] = (byte) ((xAccess.getRaw(iter) & 0xff) >> (-n));
}
} else {
while (xAccess.next(iter)) {
data[iter.getIndex()] = (byte) (xAccess.getRaw(iter) << n);
}
}
return RDataFactory.createRawVector(data);
}
return RDataFactory.createRawVector(data);
}
}
// TODO the rest of the functions
@Specialization(replaces = "rawShift")
@TruffleBoundary
protected RRawVector rawShiftGeneric(RAbstractRawVector x, int n,
@Cached("createBinaryProfile()") ConditionProfile negativeShiftProfile) {
return rawShift(x, n, negativeShiftProfile, x.slowPathAccess());
}
}
}
......@@ -25,6 +25,8 @@ package com.oracle.truffle.r.nodes.builtin.base;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
......@@ -32,6 +34,8 @@ import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
@RBuiltin(name = "rawToBits", kind = INTERNAL, parameterNames = {"x"}, behavior = PURE)
public abstract class RawToBits extends RBuiltinNode.Arg1 {
......@@ -41,17 +45,26 @@ public abstract class RawToBits extends RBuiltinNode.Arg1 {
casts.arg("x").mustNotBeNull(RError.Message.ARGUMENT_MUST_BE_RAW_VECTOR, "x").mustBe(Predef.rawValue(), RError.Message.ARGUMENT_MUST_BE_RAW_VECTOR, "x");
}
@Specialization
protected RAbstractRawVector rawToBits(RAbstractRawVector x) {
byte[] result = new byte[8 * x.getLength()];
int pos = 0;
for (int j = 0; j < x.getLength(); j++) {
byte temp = x.getRawDataAt(j);
for (int i = 0; i < 8; i++) {
result[pos++] = (byte) (temp & 1);
temp >>= 1;
@Specialization(guards = "xAccess.supports(x)")
protected RAbstractRawVector rawToBits(RAbstractRawVector x,
@Cached("x.access()") VectorAccess xAccess) {
try (SequentialIterator iter = xAccess.access(x)) {
byte[] result = new byte[8 * x.getLength()];
int pos = 0;
while (xAccess.next(iter)) {
byte temp = xAccess.getRaw(iter);
for (int i = 0; i < 8; i++) {
result[pos++] = (byte) (temp & 1);
temp >>= 1;
}
}
return RDataFactory.createRawVector(result);
}
return RDataFactory.createRawVector(result);
}
@Specialization(replaces = "rawToBits")
@TruffleBoundary
protected RAbstractRawVector rawToBitsGeneric(RAbstractRawVector x) {
return rawToBits(x, x.slowPathAccess());
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment