Skip to content
Snippets Groups Projects
Commit ebc0383d authored by Lukas Stadler's avatar Lukas Stadler
Browse files

convert raw and connection functions to VectorAccess

parent 5b5b5409
No related branches found
No related tags found
No related merge requests found
...@@ -59,6 +59,7 @@ import java.nio.ShortBuffer; ...@@ -59,6 +59,7 @@ import java.nio.ShortBuffer;
import java.nio.channels.ByteChannel; import java.nio.channels.ByteChannel;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.nio.charset.IllegalCharsetNameException; import java.nio.charset.IllegalCharsetNameException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
...@@ -90,7 +91,6 @@ import com.oracle.truffle.r.runtime.conn.SocketConnections.RSocketConnection; ...@@ -90,7 +91,6 @@ import com.oracle.truffle.r.runtime.conn.SocketConnections.RSocketConnection;
import com.oracle.truffle.r.runtime.conn.TextConnections.TextRConnection; import com.oracle.truffle.r.runtime.conn.TextConnections.TextRConnection;
import com.oracle.truffle.r.runtime.conn.URLConnections.URLRConnection; import com.oracle.truffle.r.runtime.conn.URLConnections.URLRConnection;
import com.oracle.truffle.r.runtime.context.RContext; import com.oracle.truffle.r.runtime.context.RContext;
import com.oracle.truffle.r.runtime.data.RComplex;
import com.oracle.truffle.r.runtime.data.RComplexVector; import com.oracle.truffle.r.runtime.data.RComplexVector;
import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.RDoubleVector;
...@@ -106,13 +106,12 @@ import com.oracle.truffle.r.runtime.data.RStringVector; ...@@ -106,13 +106,12 @@ import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.RTypes; import com.oracle.truffle.r.runtime.data.RTypes;
import com.oracle.truffle.r.runtime.data.RVector; import com.oracle.truffle.r.runtime.data.RVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractAtomicVector; import com.oracle.truffle.r.runtime.data.model.RAbstractAtomicVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector; import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
import com.oracle.truffle.r.runtime.env.REnvironment; import com.oracle.truffle.r.runtime.env.REnvironment;
import com.oracle.truffle.r.runtime.nodes.RBaseNode; import com.oracle.truffle.r.runtime.nodes.RBaseNode;
...@@ -1066,90 +1065,83 @@ public abstract class ConnectionFunctions { ...@@ -1066,90 +1065,83 @@ public abstract class ConnectionFunctions {
return WriteDataNodeGen.create(); return WriteDataNodeGen.create();
} }
@TruffleBoundary
private static ByteBuffer allocate(int capacity, boolean swap) { private static ByteBuffer allocate(int capacity, boolean swap) {
ByteBuffer buffer = ByteBuffer.allocate(capacity); ByteBuffer buffer = ByteBuffer.allocate(capacity);
checkOrder(buffer, swap); checkOrder(buffer, swap);
return buffer; return buffer;
} }
@Specialization
protected ByteBuffer writeInteger(RAbstractIntVector object, @SuppressWarnings("unused") int size, boolean swap, @SuppressWarnings("unused") boolean useBytes) {
int length = object.getLength();
ByteBuffer buffer = allocate(4 * length, swap);
for (int i = 0; i < length; i++) {
int value = object.getDataAt(i);
buffer.putInt(value);
}
return buffer;
}
@Specialization
protected ByteBuffer writeDouble(RAbstractDoubleVector object, @SuppressWarnings("unused") int size, boolean swap, @SuppressWarnings("unused") boolean useBytes) {
int length = object.getLength();
ByteBuffer buffer = allocate(8 * length, swap);
for (int i = 0; i < length; i++) {
double value = object.getDataAt(i);
buffer.putDouble(value);
}
return buffer;
}
@Specialization
protected ByteBuffer writeComplex(RAbstractComplexVector object, @SuppressWarnings("unused") int size, boolean swap, @SuppressWarnings("unused") boolean useBytes) {
int length = object.getLength();
ByteBuffer buffer = allocate(16 * length, swap);
for (int i = 0; i < length; i++) {
RComplex complex = object.getDataAt(i);
double re = complex.getRealPart();
double im = complex.getImaginaryPart();
buffer.putDouble(re);
buffer.putDouble(im);
}
return buffer;
}
@Specialization
@TruffleBoundary @TruffleBoundary
protected ByteBuffer writeString(RAbstractStringVector object, @SuppressWarnings("unused") int size, boolean swap, @SuppressWarnings("unused") boolean useBytes) { private static byte[] encodeString(String s) {
int length = object.getLength(); return s.getBytes(StandardCharsets.UTF_8);
byte[][] data = new byte[length][]; }
int totalLength = 0;
for (int i = 0; i < length; i++) { @Specialization(guards = "objectAccess.supports(object)")
String s = object.getDataAt(i); protected ByteBuffer write(RAbstractVector object, @SuppressWarnings("unused") int size, boolean swap, @SuppressWarnings("unused") boolean useBytes,
// There is no special encoding for NA_character_ @Cached("object.access()") VectorAccess objectAccess) {
data[i] = s.getBytes(); try (SequentialIterator iter = objectAccess.access(object)) {
totalLength = totalLength + data[i].length + 1; // zero pad int length = objectAccess.getLength(iter);
}
ByteBuffer buffer;
ByteBuffer buffer = allocate(totalLength, swap); switch (objectAccess.getType()) {
for (int i = 0; i < length; i++) { case Integer:
buffer.put(data[i]); buffer = allocate(4 * length, swap);
buffer.put((byte) 0); while (objectAccess.next(iter)) {
} buffer.putInt(objectAccess.getInt(iter));
return buffer; }
} return buffer;
case Double:
buffer = allocate(8 * length, swap);
while (objectAccess.next(iter)) {
buffer.putDouble(objectAccess.getDouble(iter));
}
return buffer;
case Complex:
buffer = allocate(16 * length, swap);
while (objectAccess.next(iter)) {
buffer.putDouble(objectAccess.getComplexR(iter));
buffer.putDouble(objectAccess.getComplexI(iter));
}
return buffer;
case Character:
byte[][] data = new byte[length][];
int totalLength = 0;
while (objectAccess.next(iter)) {
// There is no special encoding for NA_character_
data[iter.getIndex()] = encodeString(objectAccess.getString(iter));
// zero pad
totalLength = totalLength + data[iter.getIndex()].length + 1;
}
@Specialization buffer = allocate(totalLength, swap);
protected ByteBuffer writeLogical(RAbstractLogicalVector object, @SuppressWarnings("unused") int size, boolean swap, @SuppressWarnings("unused") boolean useBytes) { for (int i = 0; i < length; i++) {
// encoded as ints, with FALSE=0, TRUE=1, NA=Integer_NA_ buffer.put(data[i]);
int length = object.getLength(); buffer.put((byte) 0);
ByteBuffer buffer = allocate(4 * length, swap); }
for (int i = 0; i < length; i++) { return buffer;
byte value = object.getDataAt(i); case Logical:
int encoded = RRuntime.isNA(value) ? RRuntime.INT_NA : value == RRuntime.LOGICAL_FALSE ? 0 : 1; buffer = allocate(4 * length, swap);
buffer.putInt(encoded); while (objectAccess.next(iter)) {
buffer.putInt(objectAccess.getInt(iter)); // converted to int
}
return buffer;
case Raw:
buffer = allocate(length, swap);
while (objectAccess.next(iter)) {
buffer.put(objectAccess.getRaw(iter));
}
return buffer;
default:
throw RInternalError.shouldNotReachHere();
}
} }
return buffer;
} }
@Specialization @Specialization(replaces = "write")
protected ByteBuffer writeRaw(RAbstractRawVector object, @SuppressWarnings("unused") int size, boolean swap, @SuppressWarnings("unused") boolean useBytes) { @TruffleBoundary
int length = object.getLength(); protected ByteBuffer writeGeneric(RAbstractVector object, int size, boolean swap, boolean useBytes) {
ByteBuffer buffer = allocate(length, swap); return write(object, size, swap, useBytes, object.slowPathAccess());
for (int i = 0; i < length; i++) {
buffer.put(object.getRawDataAt(i));
}
return buffer;
} }
} }
......
...@@ -40,9 +40,11 @@ import com.oracle.truffle.r.runtime.RError; ...@@ -40,9 +40,11 @@ import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RRawVector; import com.oracle.truffle.r.runtime.data.RRawVector;
import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector; import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
/** /**
* Conversion and manipulation of objects of type "raw". * Conversion and manipulation of objects of type "raw".
...@@ -57,17 +59,26 @@ public class RawFunctions { ...@@ -57,17 +59,26 @@ public class RawFunctions {
casts.arg("x").defaultError(RError.Message.ARG_MUST_BE_CHARACTER_VECTOR_LENGTH_ONE).mustBe(stringValue()).asStringVector().mustBe(notEmpty()); casts.arg("x").defaultError(RError.Message.ARG_MUST_BE_CHARACTER_VECTOR_LENGTH_ONE).mustBe(stringValue()).asStringVector().mustBe(notEmpty());
} }
@Specialization @Specialization(guards = "xAccess.supports(x)")
protected RRawVector charToRaw(RAbstractStringVector x) { protected RRawVector charToRaw(RAbstractStringVector x,
if (x.getLength() > 1) { @Cached("x.access()") VectorAccess xAccess) {
warning(RError.Message.ARG_SHOULD_BE_CHARACTER_VECTOR_LENGTH_ONE); try (RandomIterator iter = xAccess.randomAccess(x)) {
} if (xAccess.getLength(iter) != 1) {
String s = x.getDataAt(0); warning(RError.Message.ARG_SHOULD_BE_CHARACTER_VECTOR_LENGTH_ONE);
byte[] data = new byte[s.length()]; }
for (int i = 0; i < data.length; i++) { String s = xAccess.getString(iter, 0);
data[i] = (byte) s.charAt(i); byte[] data = new byte[s.length()];
for (int i = 0; i < data.length; i++) {
data[i] = (byte) s.charAt(i);
}
return RDataFactory.createRawVector(data);
} }
return RDataFactory.createRawVector(data); }
@Specialization(replaces = "charToRaw")
@TruffleBoundary
protected RRawVector charToRawGeneric(RAbstractStringVector x) {
return charToRaw(x, x.slowPathAccess());
} }
} }
...@@ -80,28 +91,45 @@ public class RawFunctions { ...@@ -80,28 +91,45 @@ public class RawFunctions {
casts.arg("multiple").defaultError(RError.Message.INVALID_LOGICAL).asLogicalVector().findFirst().mustNotBeNA().map(toBoolean()); casts.arg("multiple").defaultError(RError.Message.INVALID_LOGICAL).asLogicalVector().findFirst().mustNotBeNA().map(toBoolean());
} }
@Specialization
@TruffleBoundary @TruffleBoundary
protected RStringVector rawToChar(RAbstractRawVector x, boolean multiple) { private static String createString(int j, byte[] data) {
RStringVector result; return new String(data, 0, j);
if (multiple) { }
String[] data = new String[x.getLength()];
for (int i = 0; i < data.length; i++) { @TruffleBoundary
data[i] = new String(new byte[]{x.getRawDataAt(i)}); private static String createString(byte value) {
} return new String(new byte[]{value});
result = RDataFactory.createStringVector(data, RDataFactory.COMPLETE_VECTOR); }
} else {
int j = 0; @Specialization(guards = "xAccess.supports(x)")
byte[] data = new byte[x.getLength()]; protected Object rawToChar(RAbstractRawVector x, boolean multiple,
for (int i = 0; i < data.length; i++) { @Cached("x.access()") VectorAccess xAccess) {
byte b = x.getRawDataAt(i); try (SequentialIterator iter = xAccess.access(x)) {
if (b != 0) { if (multiple) {
data[j++] = b; String[] data = new String[xAccess.getLength(iter)];
while (xAccess.next(iter)) {
byte value = xAccess.getRaw(iter);
data[iter.getIndex()] = createString(value);
}
return RDataFactory.createStringVector(data, RDataFactory.COMPLETE_VECTOR);
} else {
int j = 0;
byte[] data = new byte[xAccess.getLength(iter)];
while (xAccess.next(iter)) {
byte b = xAccess.getRaw(iter);
if (b != 0) {
data[j++] = b;
}
} }
return createString(j, data);
} }
result = RDataFactory.createStringVectorFromScalar(new String(data, 0, j));
} }
return result; }
@Specialization(replaces = "rawToChar")
@TruffleBoundary
protected Object rawToCharGeneric(RAbstractRawVector x, boolean multiple) {
return rawToChar(x, multiple, x.slowPathAccess());
} }
} }
...@@ -114,23 +142,30 @@ public class RawFunctions { ...@@ -114,23 +142,30 @@ public class RawFunctions {
casts.arg("n").defaultError(RError.Message.MUST_BE_SMALL_INT, "shift").asIntegerVector().findFirst().mustNotBeNA().mustBe(gte(-8).and(lte(8))); casts.arg("n").defaultError(RError.Message.MUST_BE_SMALL_INT, "shift").asIntegerVector().findFirst().mustNotBeNA().mustBe(gte(-8).and(lte(8)));
} }
@Specialization @Specialization(guards = "xAccess.supports(x)")
protected RRawVector rawShift(RAbstractRawVector x, int n, protected RRawVector rawShift(RAbstractRawVector x, int n,
@Cached("createBinaryProfile()") ConditionProfile negativeShiftProfile) { @Cached("createBinaryProfile()") ConditionProfile negativeShiftProfile,
byte[] data = new byte[x.getLength()]; @Cached("x.access()") VectorAccess xAccess) {
if (negativeShiftProfile.profile(n < 0)) { try (SequentialIterator iter = xAccess.access(x)) {
for (int i = 0; i < data.length; i++) { byte[] data = new byte[xAccess.getLength(iter)];
data[i] = (byte) ((x.getRawDataAt(i) & 0xff) >> (-n)); if (negativeShiftProfile.profile(n < 0)) {
} while (xAccess.next(iter)) {
} else { data[iter.getIndex()] = (byte) ((xAccess.getRaw(iter) & 0xff) >> (-n));
for (int i = 0; i < data.length; i++) { }
data[i] = (byte) (x.getRawDataAt(i) << n); } else {
while (xAccess.next(iter)) {
data[iter.getIndex()] = (byte) (xAccess.getRaw(iter) << n);
}
} }
return RDataFactory.createRawVector(data);
} }
return RDataFactory.createRawVector(data);
} }
}
// TODO the rest of the functions
@Specialization(replaces = "rawShift")
@TruffleBoundary
protected RRawVector rawShiftGeneric(RAbstractRawVector x, int n,
@Cached("createBinaryProfile()") ConditionProfile negativeShiftProfile) {
return rawShift(x, n, negativeShiftProfile, x.slowPathAccess());
}
}
} }
...@@ -25,6 +25,8 @@ package com.oracle.truffle.r.nodes.builtin.base; ...@@ -25,6 +25,8 @@ package com.oracle.truffle.r.nodes.builtin.base;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef; import com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
...@@ -32,6 +34,8 @@ import com.oracle.truffle.r.runtime.RError; ...@@ -32,6 +34,8 @@ import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector; import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
@RBuiltin(name = "rawToBits", kind = INTERNAL, parameterNames = {"x"}, behavior = PURE) @RBuiltin(name = "rawToBits", kind = INTERNAL, parameterNames = {"x"}, behavior = PURE)
public abstract class RawToBits extends RBuiltinNode.Arg1 { public abstract class RawToBits extends RBuiltinNode.Arg1 {
...@@ -41,17 +45,26 @@ public abstract class RawToBits extends RBuiltinNode.Arg1 { ...@@ -41,17 +45,26 @@ public abstract class RawToBits extends RBuiltinNode.Arg1 {
casts.arg("x").mustNotBeNull(RError.Message.ARGUMENT_MUST_BE_RAW_VECTOR, "x").mustBe(Predef.rawValue(), RError.Message.ARGUMENT_MUST_BE_RAW_VECTOR, "x"); casts.arg("x").mustNotBeNull(RError.Message.ARGUMENT_MUST_BE_RAW_VECTOR, "x").mustBe(Predef.rawValue(), RError.Message.ARGUMENT_MUST_BE_RAW_VECTOR, "x");
} }
@Specialization @Specialization(guards = "xAccess.supports(x)")
protected RAbstractRawVector rawToBits(RAbstractRawVector x) { protected RAbstractRawVector rawToBits(RAbstractRawVector x,
byte[] result = new byte[8 * x.getLength()]; @Cached("x.access()") VectorAccess xAccess) {
int pos = 0; try (SequentialIterator iter = xAccess.access(x)) {
for (int j = 0; j < x.getLength(); j++) { byte[] result = new byte[8 * x.getLength()];
byte temp = x.getRawDataAt(j); int pos = 0;
for (int i = 0; i < 8; i++) { while (xAccess.next(iter)) {
result[pos++] = (byte) (temp & 1); byte temp = xAccess.getRaw(iter);
temp >>= 1; for (int i = 0; i < 8; i++) {
result[pos++] = (byte) (temp & 1);
temp >>= 1;
}
} }
return RDataFactory.createRawVector(result);
} }
return RDataFactory.createRawVector(result); }
@Specialization(replaces = "rawToBits")
@TruffleBoundary
protected RAbstractRawVector rawToBitsGeneric(RAbstractRawVector x) {
return rawToBits(x, x.slowPathAccess());
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment