Skip to content
Snippets Groups Projects
Commit dbb982e0 authored by Lukas Stadler's avatar Lukas Stadler
Browse files

convert unary reduce and not nodes to VectorAccess

parent 4a53fb69
No related branches found
No related tags found
No related merge requests found
......@@ -65,8 +65,8 @@ public abstract class Range extends RBuiltinNode.Arg3 {
@Specialization(guards = "args.getLength() == 1")
protected RVector<?> rangeLengthOne(RArgsValuesAndNames args, boolean naRm, boolean finite) {
Object min = minReduce.executeReduce(args.getArgument(0), naRm, finite);
Object max = maxReduce.executeReduce(args.getArgument(0), naRm, finite);
Object min = minReduce.executeReduce(args.getArgument(0), naRm || finite, finite);
Object max = maxReduce.executeReduce(args.getArgument(0), naRm || finite, finite);
return createResult(min, max);
}
......@@ -84,8 +84,8 @@ public abstract class Range extends RBuiltinNode.Arg3 {
protected RVector<?> range(RArgsValuesAndNames args, boolean naRm, boolean finite,
@Cached("create()") Combine combine) {
Object combined = combine.executeCombine(args, false);
Object min = minReduce.executeReduce(combined, naRm, finite);
Object max = maxReduce.executeReduce(combined, naRm, finite);
Object min = minReduce.executeReduce(combined, naRm || finite, finite);
Object max = maxReduce.executeReduce(combined, naRm || finite, finite);
return createResult(min, max);
}
}
......@@ -32,34 +32,37 @@ import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.ImportStatic;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.interop.Message;
import com.oracle.truffle.api.interop.TruffleObject;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimNamesAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.RType;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RComplex;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RList;
import com.oracle.truffle.r.runtime.data.RLogicalVector;
import com.oracle.truffle.r.runtime.data.RDataFactory.VectorFactory;
import com.oracle.truffle.r.runtime.data.RRaw;
import com.oracle.truffle.r.runtime.data.RRawVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
import com.oracle.truffle.r.runtime.data.RVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
import com.oracle.truffle.r.runtime.data.nodes.VectorReuse;
import com.oracle.truffle.r.runtime.interop.ForeignArray2R;
import com.oracle.truffle.r.runtime.ops.na.NACheck;
import com.oracle.truffle.r.runtime.ops.na.NAProfile;
@ImportStatic({RRuntime.class})
@ImportStatic({RRuntime.class, ForeignArray2R.class, Message.class, RType.class})
@RBuiltin(name = "!", kind = PRIMITIVE, parameterNames = {""}, dispatch = OPS_GROUP_GENERIC, behavior = PURE_ARITHMETIC)
public abstract class UnaryNotNode extends RBuiltinNode.Arg1 {
private final NACheck na = NACheck.create();
private final NAProfile naProfile = NAProfile.create();
private final ConditionProfile zeroLengthProfile = ConditionProfile.createBinaryProfile();
@Child private GetDimAttributeNode getDims = GetDimAttributeNode.create();
@Child private GetNamesAttributeNode getNames = GetNamesAttributeNode.create();
@Child private GetDimNamesAttributeNode getDimNames = GetDimNamesAttributeNode.create();
static {
Casts.noCasts(UnaryNotNode.class);
......@@ -77,10 +80,6 @@ public abstract class UnaryNotNode extends RBuiltinNode.Arg1 {
return RRuntime.asLogical(operand == 0);
}
private static byte not(RComplex operand) {
return RRuntime.asLogical(operand.getRealPart() == 0 && operand.getImaginaryPart() == 0);
}
private static byte notRaw(RRaw operand) {
return notRaw(operand.getValue());
}
......@@ -109,112 +108,83 @@ public abstract class UnaryNotNode extends RBuiltinNode.Arg1 {
return RDataFactory.createRaw(notRaw(operand));
}
@Specialization
protected RLogicalVector doLogicalVector(RLogicalVector vector) {
int length = vector.getLength();
byte[] result;
if (zeroLengthProfile.profile(length == 0)) {
result = new byte[0];
} else {
na.enable(vector);
result = new byte[length];
for (int i = 0; i < length; i++) {
byte value = vector.getDataAt(i);
result[i] = na.check(value) ? RRuntime.LOGICAL_NA : not(value);
@Specialization(guards = {"vectorAccess.supports(vector)", "reuse.supports(vector)"})
protected RAbstractVector doLogicalVectorCached(RAbstractLogicalVector vector,
@Cached("vector.access()") VectorAccess vectorAccess,
@Cached("createTemporary(vector)") VectorReuse reuse) {
RAbstractVector result = reuse.getResult(vector);
VectorAccess resultAccess = reuse.access(result);
try (SequentialIterator vectorIter = vectorAccess.access(vector); SequentialIterator resultIter = resultAccess.access(result)) {
while (vectorAccess.next(vectorIter) && resultAccess.next(resultIter)) {
byte value = vectorAccess.getLogical(vectorIter);
resultAccess.setLogical(resultIter, vectorAccess.na.check(value) ? RRuntime.LOGICAL_NA : not(value));
}
}
RLogicalVector resultVector = RDataFactory.createLogicalVector(result, na.neverSeenNA());
resultVector.copyAttributesFrom(vector);
return resultVector;
}
@Specialization
protected RLogicalVector doIntVector(RAbstractIntVector vector) {
int length = vector.getLength();
byte[] result;
if (zeroLengthProfile.profile(length == 0)) {
result = new byte[0];
} else {
na.enable(vector);
result = new byte[length];
for (int i = 0; i < length; i++) {
int value = vector.getDataAt(i);
result[i] = na.check(value) ? RRuntime.LOGICAL_NA : not(value);
}
}
RLogicalVector resultVector = RDataFactory.createLogicalVector(result, na.neverSeenNA());
copyNamesDimsDimNames(vector, resultVector);
return resultVector;
}
@Specialization
protected RLogicalVector doDoubleVector(RAbstractDoubleVector vector) {
int length = vector.getLength();
byte[] result;
if (zeroLengthProfile.profile(length == 0)) {
result = new byte[0];
} else {
na.enable(vector);
result = new byte[length];
for (int i = 0; i < length; i++) {
double value = vector.getDataAt(i);
result[i] = na.check(value) ? RRuntime.LOGICAL_NA : not(value);
}
}
RLogicalVector resultVector = RDataFactory.createLogicalVector(result, na.neverSeenNA());
copyNamesDimsDimNames(vector, resultVector);
return resultVector;
}
@Specialization
protected RLogicalVector doComplexVector(RAbstractComplexVector vector) {
int length = vector.getLength();
byte[] result;
if (zeroLengthProfile.profile(length == 0)) {
result = new byte[0];
} else {
na.enable(vector);
result = new byte[length];
for (int i = 0; i < length; i++) {
RComplex value = vector.getDataAt(i);
result[i] = na.check(value) ? RRuntime.LOGICAL_NA : not(value);
}
}
RLogicalVector resultVector = RDataFactory.createLogicalVector(result, na.neverSeenNA());
copyNamesDimsDimNames(vector, resultVector);
return resultVector;
result.setComplete(vectorAccess.na.neverSeenNA());
return result;
}
@Specialization(replaces = "doLogicalVectorCached")
@TruffleBoundary
private void copyNamesDimsDimNames(RAbstractVector vector, RLogicalVector resultVector) {
resultVector.copyNamesDimsDimNamesFrom(vector, this);
}
@Specialization
protected RRawVector doRawVector(RRawVector vector) {
int length = vector.getLength();
byte[] result;
if (zeroLengthProfile.profile(length == 0)) {
result = new byte[0];
} else {
result = new byte[length];
for (int i = 0; i < length; i++) {
result[i] = notRaw(vector.getRawDataAt(i));
protected RAbstractVector doLogicalGenericGeneric(RAbstractLogicalVector vector,
@Cached("createTemporaryGeneric()") VectorReuse reuse) {
return doLogicalVectorCached(vector, vector.slowPathAccess(), reuse);
}
@Specialization(guards = {"vectorAccess.supports(vector)", "!isRAbstractLogicalVector(vector)"})
protected RAbstractVector doVectorCached(RAbstractVector vector,
@Cached("vector.access()") VectorAccess vectorAccess,
@Cached("createNew(Logical)") VectorAccess resultAccess,
@Cached("createNew(Raw)") VectorAccess rawResultAccess,
@Cached("create()") VectorFactory factory) {
try (SequentialIterator vectorIter = vectorAccess.access(vector)) {
int length = vectorAccess.getLength(vectorIter);
RAbstractVector result;
switch (vectorAccess.getType()) {
case Character:
case List:
case Expression:
// special cases:
if (length == 0) {
return factory.createEmptyLogicalVector();
} else {
throw error(RError.Message.INVALID_ARG_TYPE);
}
case Raw:
result = factory.createRawVector(length);
try (SequentialIterator resultIter = rawResultAccess.access(result)) {
// raw does not produce a logical result, but (255 - value)
while (vectorAccess.next(vectorIter) && rawResultAccess.next(resultIter)) {
rawResultAccess.setRaw(resultIter, notRaw(vectorAccess.getRaw(vectorIter)));
}
}
((RVector<?>) result).copyAttributesFrom(vector);
break;
default:
result = factory.createLogicalVector(length, false);
try (SequentialIterator resultIter = resultAccess.access(result)) {
while (vectorAccess.next(vectorIter) && resultAccess.next(resultIter)) {
byte value = vectorAccess.getLogical(vectorIter);
resultAccess.setLogical(resultIter, vectorAccess.na.check(value) ? RRuntime.LOGICAL_NA : not(value));
}
}
if (vectorAccess.getType() == RType.Logical) {
((RVector<?>) result).copyAttributesFrom(vector);
} else {
factory.reinitializeAttributes((RVector<?>) result, getDims.getDimensions(vector), getNames.getNames(vector), getDimNames.getDimNames(vector));
}
break;
}
result.setComplete(vectorAccess.na.neverSeenNA());
return result;
}
RRawVector resultVector = RDataFactory.createRawVector(result);
resultVector.copyAttributesFrom(vector);
return resultVector;
}
@Specialization(guards = {"vector.getLength() == 0"})
protected RLogicalVector doStringVector(@SuppressWarnings("unused") RAbstractStringVector vector) {
return RDataFactory.createEmptyLogicalVector();
}
@Specialization(guards = {"list.getLength() == 0"})
protected RLogicalVector doList(@SuppressWarnings("unused") RList list) {
return RDataFactory.createEmptyLogicalVector();
@Specialization(replaces = "doVectorCached", guards = "!isRAbstractLogicalVector(vector)")
@TruffleBoundary
protected RAbstractVector doGenericGeneric(RAbstractVector vector,
@Cached("create()") VectorFactory factory) {
return doVectorCached(vector, vector.slowPathAccess(), VectorAccess.createSlowPathNew(RType.Logical), VectorAccess.createSlowPathNew(RType.Raw), factory);
}
@Specialization(guards = {"isForeignObject(obj)"})
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment