Skip to content
Snippets Groups Projects
Commit dbb982e0 authored by Lukas Stadler's avatar Lukas Stadler
Browse files

convert unary reduce and not nodes to VectorAccess

parent 4a53fb69
No related branches found
No related tags found
No related merge requests found
...@@ -65,8 +65,8 @@ public abstract class Range extends RBuiltinNode.Arg3 { ...@@ -65,8 +65,8 @@ public abstract class Range extends RBuiltinNode.Arg3 {
@Specialization(guards = "args.getLength() == 1") @Specialization(guards = "args.getLength() == 1")
protected RVector<?> rangeLengthOne(RArgsValuesAndNames args, boolean naRm, boolean finite) { protected RVector<?> rangeLengthOne(RArgsValuesAndNames args, boolean naRm, boolean finite) {
Object min = minReduce.executeReduce(args.getArgument(0), naRm, finite); Object min = minReduce.executeReduce(args.getArgument(0), naRm || finite, finite);
Object max = maxReduce.executeReduce(args.getArgument(0), naRm, finite); Object max = maxReduce.executeReduce(args.getArgument(0), naRm || finite, finite);
return createResult(min, max); return createResult(min, max);
} }
...@@ -84,8 +84,8 @@ public abstract class Range extends RBuiltinNode.Arg3 { ...@@ -84,8 +84,8 @@ public abstract class Range extends RBuiltinNode.Arg3 {
protected RVector<?> range(RArgsValuesAndNames args, boolean naRm, boolean finite, protected RVector<?> range(RArgsValuesAndNames args, boolean naRm, boolean finite,
@Cached("create()") Combine combine) { @Cached("create()") Combine combine) {
Object combined = combine.executeCombine(args, false); Object combined = combine.executeCombine(args, false);
Object min = minReduce.executeReduce(combined, naRm, finite); Object min = minReduce.executeReduce(combined, naRm || finite, finite);
Object max = maxReduce.executeReduce(combined, naRm, finite); Object max = maxReduce.executeReduce(combined, naRm || finite, finite);
return createResult(min, max); return createResult(min, max);
} }
} }
...@@ -32,34 +32,37 @@ import com.oracle.truffle.api.dsl.Fallback; ...@@ -32,34 +32,37 @@ import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.ImportStatic; import com.oracle.truffle.api.dsl.ImportStatic;
import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.interop.Message;
import com.oracle.truffle.api.interop.TruffleObject; import com.oracle.truffle.api.interop.TruffleObject;
import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimNamesAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.RType;
import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RComplex;
import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RDataFactory.VectorFactory;
import com.oracle.truffle.r.runtime.data.RLogicalVector;
import com.oracle.truffle.r.runtime.data.RRaw; import com.oracle.truffle.r.runtime.data.RRaw;
import com.oracle.truffle.r.runtime.data.RRawVector; import com.oracle.truffle.r.runtime.data.RVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
import com.oracle.truffle.r.runtime.data.nodes.VectorReuse;
import com.oracle.truffle.r.runtime.interop.ForeignArray2R; import com.oracle.truffle.r.runtime.interop.ForeignArray2R;
import com.oracle.truffle.r.runtime.ops.na.NACheck;
import com.oracle.truffle.r.runtime.ops.na.NAProfile; import com.oracle.truffle.r.runtime.ops.na.NAProfile;
@ImportStatic({RRuntime.class}) @ImportStatic({RRuntime.class, ForeignArray2R.class, Message.class, RType.class})
@RBuiltin(name = "!", kind = PRIMITIVE, parameterNames = {""}, dispatch = OPS_GROUP_GENERIC, behavior = PURE_ARITHMETIC) @RBuiltin(name = "!", kind = PRIMITIVE, parameterNames = {""}, dispatch = OPS_GROUP_GENERIC, behavior = PURE_ARITHMETIC)
public abstract class UnaryNotNode extends RBuiltinNode.Arg1 { public abstract class UnaryNotNode extends RBuiltinNode.Arg1 {
private final NACheck na = NACheck.create();
private final NAProfile naProfile = NAProfile.create(); private final NAProfile naProfile = NAProfile.create();
private final ConditionProfile zeroLengthProfile = ConditionProfile.createBinaryProfile();
@Child private GetDimAttributeNode getDims = GetDimAttributeNode.create();
@Child private GetNamesAttributeNode getNames = GetNamesAttributeNode.create();
@Child private GetDimNamesAttributeNode getDimNames = GetDimNamesAttributeNode.create();
static { static {
Casts.noCasts(UnaryNotNode.class); Casts.noCasts(UnaryNotNode.class);
...@@ -77,10 +80,6 @@ public abstract class UnaryNotNode extends RBuiltinNode.Arg1 { ...@@ -77,10 +80,6 @@ public abstract class UnaryNotNode extends RBuiltinNode.Arg1 {
return RRuntime.asLogical(operand == 0); return RRuntime.asLogical(operand == 0);
} }
private static byte not(RComplex operand) {
return RRuntime.asLogical(operand.getRealPart() == 0 && operand.getImaginaryPart() == 0);
}
private static byte notRaw(RRaw operand) { private static byte notRaw(RRaw operand) {
return notRaw(operand.getValue()); return notRaw(operand.getValue());
} }
...@@ -109,112 +108,83 @@ public abstract class UnaryNotNode extends RBuiltinNode.Arg1 { ...@@ -109,112 +108,83 @@ public abstract class UnaryNotNode extends RBuiltinNode.Arg1 {
return RDataFactory.createRaw(notRaw(operand)); return RDataFactory.createRaw(notRaw(operand));
} }
@Specialization @Specialization(guards = {"vectorAccess.supports(vector)", "reuse.supports(vector)"})
protected RLogicalVector doLogicalVector(RLogicalVector vector) { protected RAbstractVector doLogicalVectorCached(RAbstractLogicalVector vector,
int length = vector.getLength(); @Cached("vector.access()") VectorAccess vectorAccess,
byte[] result; @Cached("createTemporary(vector)") VectorReuse reuse) {
if (zeroLengthProfile.profile(length == 0)) { RAbstractVector result = reuse.getResult(vector);
result = new byte[0]; VectorAccess resultAccess = reuse.access(result);
} else { try (SequentialIterator vectorIter = vectorAccess.access(vector); SequentialIterator resultIter = resultAccess.access(result)) {
na.enable(vector); while (vectorAccess.next(vectorIter) && resultAccess.next(resultIter)) {
result = new byte[length]; byte value = vectorAccess.getLogical(vectorIter);
for (int i = 0; i < length; i++) { resultAccess.setLogical(resultIter, vectorAccess.na.check(value) ? RRuntime.LOGICAL_NA : not(value));
byte value = vector.getDataAt(i);
result[i] = na.check(value) ? RRuntime.LOGICAL_NA : not(value);
} }
} }
RLogicalVector resultVector = RDataFactory.createLogicalVector(result, na.neverSeenNA()); result.setComplete(vectorAccess.na.neverSeenNA());
resultVector.copyAttributesFrom(vector); return result;
return resultVector;
}
@Specialization
protected RLogicalVector doIntVector(RAbstractIntVector vector) {
int length = vector.getLength();
byte[] result;
if (zeroLengthProfile.profile(length == 0)) {
result = new byte[0];
} else {
na.enable(vector);
result = new byte[length];
for (int i = 0; i < length; i++) {
int value = vector.getDataAt(i);
result[i] = na.check(value) ? RRuntime.LOGICAL_NA : not(value);
}
}
RLogicalVector resultVector = RDataFactory.createLogicalVector(result, na.neverSeenNA());
copyNamesDimsDimNames(vector, resultVector);
return resultVector;
}
@Specialization
protected RLogicalVector doDoubleVector(RAbstractDoubleVector vector) {
int length = vector.getLength();
byte[] result;
if (zeroLengthProfile.profile(length == 0)) {
result = new byte[0];
} else {
na.enable(vector);
result = new byte[length];
for (int i = 0; i < length; i++) {
double value = vector.getDataAt(i);
result[i] = na.check(value) ? RRuntime.LOGICAL_NA : not(value);
}
}
RLogicalVector resultVector = RDataFactory.createLogicalVector(result, na.neverSeenNA());
copyNamesDimsDimNames(vector, resultVector);
return resultVector;
}
@Specialization
protected RLogicalVector doComplexVector(RAbstractComplexVector vector) {
int length = vector.getLength();
byte[] result;
if (zeroLengthProfile.profile(length == 0)) {
result = new byte[0];
} else {
na.enable(vector);
result = new byte[length];
for (int i = 0; i < length; i++) {
RComplex value = vector.getDataAt(i);
result[i] = na.check(value) ? RRuntime.LOGICAL_NA : not(value);
}
}
RLogicalVector resultVector = RDataFactory.createLogicalVector(result, na.neverSeenNA());
copyNamesDimsDimNames(vector, resultVector);
return resultVector;
} }
@Specialization(replaces = "doLogicalVectorCached")
@TruffleBoundary @TruffleBoundary
private void copyNamesDimsDimNames(RAbstractVector vector, RLogicalVector resultVector) { protected RAbstractVector doLogicalGenericGeneric(RAbstractLogicalVector vector,
resultVector.copyNamesDimsDimNamesFrom(vector, this); @Cached("createTemporaryGeneric()") VectorReuse reuse) {
} return doLogicalVectorCached(vector, vector.slowPathAccess(), reuse);
}
@Specialization
protected RRawVector doRawVector(RRawVector vector) { @Specialization(guards = {"vectorAccess.supports(vector)", "!isRAbstractLogicalVector(vector)"})
int length = vector.getLength(); protected RAbstractVector doVectorCached(RAbstractVector vector,
byte[] result; @Cached("vector.access()") VectorAccess vectorAccess,
if (zeroLengthProfile.profile(length == 0)) { @Cached("createNew(Logical)") VectorAccess resultAccess,
result = new byte[0]; @Cached("createNew(Raw)") VectorAccess rawResultAccess,
} else { @Cached("create()") VectorFactory factory) {
result = new byte[length]; try (SequentialIterator vectorIter = vectorAccess.access(vector)) {
for (int i = 0; i < length; i++) { int length = vectorAccess.getLength(vectorIter);
result[i] = notRaw(vector.getRawDataAt(i)); RAbstractVector result;
switch (vectorAccess.getType()) {
case Character:
case List:
case Expression:
// special cases:
if (length == 0) {
return factory.createEmptyLogicalVector();
} else {
throw error(RError.Message.INVALID_ARG_TYPE);
}
case Raw:
result = factory.createRawVector(length);
try (SequentialIterator resultIter = rawResultAccess.access(result)) {
// raw does not produce a logical result, but (255 - value)
while (vectorAccess.next(vectorIter) && rawResultAccess.next(resultIter)) {
rawResultAccess.setRaw(resultIter, notRaw(vectorAccess.getRaw(vectorIter)));
}
}
((RVector<?>) result).copyAttributesFrom(vector);
break;
default:
result = factory.createLogicalVector(length, false);
try (SequentialIterator resultIter = resultAccess.access(result)) {
while (vectorAccess.next(vectorIter) && resultAccess.next(resultIter)) {
byte value = vectorAccess.getLogical(vectorIter);
resultAccess.setLogical(resultIter, vectorAccess.na.check(value) ? RRuntime.LOGICAL_NA : not(value));
}
}
if (vectorAccess.getType() == RType.Logical) {
((RVector<?>) result).copyAttributesFrom(vector);
} else {
factory.reinitializeAttributes((RVector<?>) result, getDims.getDimensions(vector), getNames.getNames(vector), getDimNames.getDimNames(vector));
}
break;
} }
result.setComplete(vectorAccess.na.neverSeenNA());
return result;
} }
RRawVector resultVector = RDataFactory.createRawVector(result);
resultVector.copyAttributesFrom(vector);
return resultVector;
} }
@Specialization(guards = {"vector.getLength() == 0"}) @Specialization(replaces = "doVectorCached", guards = "!isRAbstractLogicalVector(vector)")
protected RLogicalVector doStringVector(@SuppressWarnings("unused") RAbstractStringVector vector) { @TruffleBoundary
return RDataFactory.createEmptyLogicalVector(); protected RAbstractVector doGenericGeneric(RAbstractVector vector,
} @Cached("create()") VectorFactory factory) {
return doVectorCached(vector, vector.slowPathAccess(), VectorAccess.createSlowPathNew(RType.Logical), VectorAccess.createSlowPathNew(RType.Raw), factory);
@Specialization(guards = {"list.getLength() == 0"})
protected RLogicalVector doList(@SuppressWarnings("unused") RList list) {
return RDataFactory.createEmptyLogicalVector();
} }
@Specialization(guards = {"isForeignObject(obj)"}) @Specialization(guards = {"isForeignObject(obj)"})
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment