diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ConvertBooleanNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ConvertBooleanNode.java index ff55cf9f8b2ea74de50f2a1f1f5752d537e31fd9..0541c5523b79038b7dbd6c335f7e281640464bd9 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ConvertBooleanNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ConvertBooleanNode.java @@ -32,18 +32,17 @@ import com.oracle.truffle.api.frame.VirtualFrame; import com.oracle.truffle.api.interop.TruffleObject; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.RInternalError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.data.RComplex; -import com.oracle.truffle.r.runtime.data.RComplexVector; -import com.oracle.truffle.r.runtime.data.RList; -import com.oracle.truffle.r.runtime.data.RLogicalVector; +import com.oracle.truffle.r.runtime.data.RLogical; import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RRaw; -import com.oracle.truffle.r.runtime.data.RRawVector; -import com.oracle.truffle.r.runtime.data.RStringVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractAtomicVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractListVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; import com.oracle.truffle.r.runtime.interop.ForeignArray2R; import com.oracle.truffle.r.runtime.nodes.RNode; import com.oracle.truffle.r.runtime.nodes.RSyntaxNode; @@ -53,6 +52,8 @@ import com.oracle.truffle.r.runtime.ops.na.NAProfile; @ImportStatic(RRuntime.class) public abstract class ConvertBooleanNode extends RNode { + protected static final int ATOMIC_VECTOR_LIMIT = 8; + private final NAProfile naProfile = NAProfile.create(); private final BranchProfile invalidElementCountBranch = BranchProfile.create(); @Child private ConvertBooleanNode recursiveConvertBoolean; @@ -120,10 +121,16 @@ public abstract class ConvertBooleanNode extends RNode { return RRuntime.raw2logical(value.getValue()); } - private void checkLength(RAbstractVector value) { - if (value.getLength() != 1) { + @Specialization + protected byte doRLogical(RLogical value) { + // fast path for very common case, handled also in doAtomicVector + return value.getValue(); + } + + private void checkLength(int length) { + if (length != 1) { invalidElementCountBranch.enter(); - if (value.getLength() == 0) { + if (length == 0) { throw error(RError.Message.LENGTH_ZERO); } else { warning(RError.Message.LENGTH_GT_1); @@ -131,46 +138,33 @@ public abstract class ConvertBooleanNode extends RNode { } } - @Specialization - protected byte doIntVector(RAbstractIntVector value) { - checkLength(value); - return doInt(value.getDataAt(0)); - } - - @Specialization - protected byte doDoubleVector(RAbstractDoubleVector value) { - checkLength(value); - return doDouble(value.getDataAt(0)); - } - - @Specialization - protected byte doLogicalVector(RLogicalVector value) { - checkLength(value); - return doLogical(value.getDataAt(0)); - } - - @Specialization - protected byte doComplexVector(RComplexVector value) { - checkLength(value); - return doComplex(value.getDataAt(0)); - } - - @Specialization - protected byte doStringVector(RStringVector value) { - checkLength(value); - return doString(value.getDataAt(0)); - } - - @Specialization - protected byte doRawVector(RRawVector value) { - checkLength(value); - return RRuntime.raw2logical(value.getRawDataAt(0)); + @Specialization(guards = "access.supports(value)", limit = "ATOMIC_VECTOR_LIMIT") + protected byte doVector(RAbstractVector value, + @Cached("value.access()") VectorAccess access) { + SequentialIterator it = access.access(value); + checkLength(access.getLength(it)); + access.next(it); + switch (access.getType()) { + case Integer: + return doInt(access.getInt(it)); + case Double: + return doDouble(access.getDouble(it)); + case Raw: + return RRuntime.raw2logical(access.getRaw(it)); + case Logical: + return doLogical(access.getLogical(it)); + case Character: + return doString(access.getString(it)); + case Complex: + return doComplex(access.getComplex(it)); + default: + throw error(RError.Message.ARGUMENT_NOT_INTERPRETABLE_LOGICAL); + } } - @Specialization - protected byte doRawVector(RList value) { - checkLength(value); - throw error(RError.Message.ARGUMENT_NOT_INTERPRETABLE_LOGICAL); + @Specialization(replaces = "doVector") + protected byte doVectorGeneric(RAbstractVector value) { + return doVector(value, value.slowPathAccess()); } @Specialization(guards = "isForeignObject(obj)") @@ -192,8 +186,7 @@ public abstract class ConvertBooleanNode extends RNode { if (node instanceof ConvertBooleanNode) { return (ConvertBooleanNode) node; } - ConvertBooleanNode result = ConvertBooleanNodeGen.create(node.asRNode()); - return result; + return ConvertBooleanNodeGen.create(node.asRNode()); } @Override