diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryArithmeticNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryArithmeticNode.java index 6af290e8b548d3f283e78459085726d2507c3bb6..96d5b62296049ad4bc60fde9876f0c20d9535a03 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryArithmeticNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryArithmeticNode.java @@ -87,29 +87,27 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode.Arg2 { } @Specialization(limit = "CACHE_LIMIT", guards = {"cached != null", "cached.isSupported(left, right)"}) - protected Object doNumericVectorCached(Object left, Object right, + protected Object doNumericVectorCached(RAbstractVector left, RAbstractVector right, @Cached("createFastCached(left, right)") BinaryMapNode cached) { return cached.apply(left, right); } @Specialization(replaces = "doNumericVectorCached", guards = {"isNumericVector(left)", "isNumericVector(right)"}) @TruffleBoundary - protected Object doNumericVectorGeneric(Object left, Object right, + protected Object doNumericVectorGeneric(RAbstractVector left, RAbstractVector right, @Cached("binary.createOperation()") BinaryArithmetic arithmetic, - @Cached("new(createCached(arithmetic, left, right))") GenericNumericVectorNode generic) { - RAbstractVector leftVector = (RAbstractVector) left; - RAbstractVector rightVector = (RAbstractVector) right; - return generic.get(arithmetic, leftVector, rightVector).apply(leftVector, rightVector); + @Cached("createGeneric()") GenericNumericVectorNode generic) { + return generic.get(arithmetic, left, right).apply(left, right); } - protected BinaryMapNode createFastCached(Object left, Object right) { + protected BinaryMapNode createFastCached(RAbstractVector left, RAbstractVector right) { if (isNumericVector(left) && isNumericVector(right)) { - return createCached(binary.createOperation(), left, right); + return createCached(binary.createOperation(), left, right, false); } return null; } - protected static boolean isNumericVector(Object value) { + protected static boolean isNumericVector(RAbstractVector value) { return value instanceof RAbstractIntVector || value instanceof RAbstractDoubleVector || value instanceof RAbstractComplexVector || value instanceof RAbstractLogicalVector; } @@ -133,19 +131,19 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode.Arg2 { } @Specialization(guards = {"isNumericVector(right)"}) - protected Object doLeftNull(@SuppressWarnings("unused") RNull left, Object right, + protected Object doLeftNull(@SuppressWarnings("unused") RNull left, RAbstractVector right, @Cached("createClassProfile()") ValueProfile classProfile) { - RType rType = ((RAbstractVector) classProfile.profile(right)).getRType(); - if (rType == RType.Complex) { + RType type = classProfile.profile(right).getRType(); + if (type == RType.Complex) { return RDataFactory.createEmptyComplexVector(); } else { - if (rType == RType.Integer || rType == RType.Logical) { + if (type == RType.Integer || type == RType.Logical) { if (operation instanceof BinaryArithmetic.Div || operation instanceof BinaryArithmetic.Pow) { return RType.Double.getEmpty(); } else { return RType.Integer.getEmpty(); } - } else if (rType == RType.Double) { + } else if (type == RType.Double) { return RType.Double.getEmpty(); } else { throw error(Message.NON_NUMERIC_BINARY); @@ -154,7 +152,7 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode.Arg2 { } @Specialization(guards = {"isNumericVector(left)"}) - protected Object doRightNull(Object left, RNull right, + protected Object doRightNull(RAbstractVector left, RNull right, @Cached("createClassProfile()") ValueProfile classProfile) { return doLeftNull(right, left, classProfile); } @@ -164,7 +162,7 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode.Arg2 { throw error(Message.NON_NUMERIC_BINARY); } - protected static BinaryMapNode createCached(BinaryArithmetic innerArithmetic, Object left, Object right) { + protected static BinaryMapNode createCached(BinaryArithmetic innerArithmetic, Object left, Object right, boolean isGeneric) { RAbstractVector leftVector = (RAbstractVector) left; RAbstractVector rightVector = (RAbstractVector) right; @@ -174,22 +172,22 @@ public abstract class BinaryArithmeticNode extends RBuiltinNode.Arg2 { resultType = RType.Double; } - return BinaryMapNode.create(new BinaryMapArithmeticFunctionNode(innerArithmetic), leftVector, rightVector, argumentType, resultType, true); + return BinaryMapNode.create(new BinaryMapArithmeticFunctionNode(innerArithmetic), leftVector, rightVector, argumentType, resultType, true, isGeneric); + } + + protected static GenericNumericVectorNode createGeneric() { + return new GenericNumericVectorNode(); } protected static final class GenericNumericVectorNode extends TruffleBoundaryNode { @Child private BinaryMapNode cached; - public GenericNumericVectorNode(BinaryMapNode cachedOperation) { - this.cached = insert(cachedOperation); - } - public BinaryMapNode get(BinaryArithmetic arithmetic, RAbstractVector left, RAbstractVector right) { CompilerAsserts.neverPartOfCompilation(); BinaryMapNode map = cached; - if (!map.isSupported(left, right)) { - cached = map = map.replace(createCached(arithmetic, left, right)); + if (map == null || !map.isSupported(left, right)) { + cached = map = insert(createCached(arithmetic, left, right, true)); } return map; } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryBooleanNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryBooleanNode.java index bae281a74708b2017a5659fe5a48430fa37bf444..b367673e3376d89d14d59cd085f5ee232c0b4cbd 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryBooleanNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/binary/BinaryBooleanNode.java @@ -104,24 +104,22 @@ public abstract class BinaryBooleanNode extends RBuiltinNode.Arg2 { } @Specialization(limit = "CACHE_LIMIT", guards = {"cached != null", "cached.isSupported(left, right)"}) - protected Object doNumericVectorCached(Object left, Object right, + protected Object doNumericVectorCached(RAbstractVector left, RAbstractVector right, @Cached("createFastCached(left, right)") BinaryMapNode cached) { return cached.apply(left, right); } @Specialization(replaces = "doNumericVectorCached", guards = "isSupported(left, right)") @TruffleBoundary - protected Object doNumericVectorGeneric(Object left, Object right, + protected Object doNumericVectorGeneric(RAbstractVector left, RAbstractVector right, @Cached("factory.createOperation()") BooleanOperation operation, - @Cached("new(createCached(operation, left, right))") GenericNumericVectorNode generic) { - RAbstractVector leftVector = (RAbstractVector) left; - RAbstractVector rightVector = (RAbstractVector) right; - return generic.get(operation, leftVector, rightVector).apply(leftVector, rightVector); + @Cached("createGeneric()") GenericNumericVectorNode generic) { + return generic.get(operation, left, right).apply(left, right); } protected BinaryMapNode createFastCached(Object left, Object right) { if (isSupported(left, right)) { - return createCached(factory.createOperation(), left, right); + return createCached(factory.createOperation(), left, right, false); } return null; } @@ -243,7 +241,7 @@ public abstract class BinaryBooleanNode extends RBuiltinNode.Arg2 { throw error(Message.OPERATIONS_NUMERIC_LOGICAL_COMPLEX); } - protected static BinaryMapNode createCached(BooleanOperation operation, Object left, Object right) { + protected static BinaryMapNode createCached(BooleanOperation operation, Object left, Object right, boolean isGeneric) { RAbstractVector leftVector = (RAbstractVector) left; RAbstractVector rightVector = (RAbstractVector) right; @@ -255,23 +253,24 @@ public abstract class BinaryBooleanNode extends RBuiltinNode.Arg2 { resultType = RType.Logical; } - return BinaryMapNode.create(new BinaryMapBooleanFunctionNode(operation), leftVector, rightVector, argumentType, resultType, false); + return BinaryMapNode.create(new BinaryMapBooleanFunctionNode(operation), leftVector, rightVector, argumentType, resultType, false, isGeneric); + } + + protected static GenericNumericVectorNode createGeneric() { + return new GenericNumericVectorNode(); } protected static final class GenericNumericVectorNode extends TruffleBoundaryNode { @Child private BinaryMapNode cached; - public GenericNumericVectorNode(BinaryMapNode cachedOperation) { - this.cached = insert(cachedOperation); - } - private BinaryMapNode get(BooleanOperation arithmetic, RAbstractVector left, RAbstractVector right) { CompilerAsserts.neverPartOfCompilation(); - if (!cached.isSupported(left, right)) { - cached = cached.replace(createCached(arithmetic, left, right)); + BinaryMapNode map = cached; + if (map == null || !map.isSupported(left, right)) { + cached = map = insert(createCached(arithmetic, left, right, true)); } - return cached; + return map; } } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/primitive/BinaryMapNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/primitive/BinaryMapNode.java index a77f74f63d9ab98891b5b20991a0594958294612..51200dd1cf40613c23a7d68ef8b432764d1a3eec 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/primitive/BinaryMapNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/primitive/BinaryMapNode.java @@ -22,8 +22,8 @@ */ package com.oracle.truffle.r.nodes.primitive; +import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.dsl.Cached; -import com.oracle.truffle.api.dsl.ImportStatic; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.api.profiles.LoopConditionProfile; @@ -31,7 +31,6 @@ import com.oracle.truffle.r.nodes.attributes.CopyAttributesNode; import com.oracle.truffle.r.nodes.attributes.CopyAttributesNodeGen; import com.oracle.truffle.r.nodes.attributes.HasFixedAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; -import com.oracle.truffle.r.nodes.primitive.BinaryMapNodeFactory.VectorMapBinaryInternalNodeGen; import com.oracle.truffle.r.nodes.profile.VectorLengthProfile; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RInternalError; @@ -42,110 +41,229 @@ import com.oracle.truffle.r.runtime.data.RRaw; import com.oracle.truffle.r.runtime.data.RScalarVector; import com.oracle.truffle.r.runtime.data.RShareable; import com.oracle.truffle.r.runtime.data.RVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; -import com.oracle.truffle.r.runtime.data.nodes.GetDataStore; -import com.oracle.truffle.r.runtime.data.nodes.SetDataAt; -import com.oracle.truffle.r.runtime.data.nodes.VectorIterator; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; import com.oracle.truffle.r.runtime.nodes.RBaseNode; -/** - * Implements a binary map operation that maps two vectors into a single result vector of the - * maximum size of both vectors. Vectors with smaller length are repeated. The actual implementation - * is provided using a {@link BinaryMapFunctionNode}. - * - * The implementation tries to share input vectors if they are implementing {@link RShareable}. - */ -public final class BinaryMapNode extends RBaseNode { +final class BinaryMapScalarNode extends BinaryMapNode { + + @Child private VectorAccess leftAccess; + @Child private VectorAccess rightAccess; + + BinaryMapScalarNode(BinaryMapFunctionNode function, RAbstractVector left, RAbstractVector right, RType argumentType, RType resultType) { + super(function, left, right, argumentType, resultType); + this.leftAccess = left.access(); + this.rightAccess = right.access(); + } + + @Override + public boolean isSupported(RAbstractVector left, RAbstractVector right) { + return leftAccess.supports(left) && rightAccess.supports(right); + } + + @Override + public Object apply(RAbstractVector originalLeft, RAbstractVector originalRight) { + assert isSupported(originalLeft, originalRight); + RAbstractVector left = leftClass.cast(originalLeft); + RAbstractVector right = rightClass.cast(originalRight); + try (RandomIterator leftIter = leftAccess.randomAccess(left); RandomIterator rightIter = rightAccess.randomAccess(right)) { + + assert left != null; + assert right != null; + function.enable(left, right); + assert leftAccess.getLength(leftIter) == 1; + assert rightAccess.getLength(rightIter) == 1; + + switch (argumentType) { + case Raw: + byte leftValueRaw = leftAccess.getRaw(leftIter, 0); + byte rightValueRaw = rightAccess.getRaw(rightIter, 0); + switch (resultType) { + case Raw: + return RRaw.valueOf(function.applyRaw(leftValueRaw, rightValueRaw)); + case Logical: + return function.applyLogical(RRuntime.raw2int(leftValueRaw), RRuntime.raw2int(rightValueRaw)); + default: + throw RInternalError.shouldNotReachHere(); + } + case Logical: + byte leftValueLogical = leftAccess.getLogical(leftIter, 0); + byte rightValueLogical = rightAccess.getLogical(rightIter, 0); + return function.applyLogical(leftValueLogical, rightValueLogical); + case Integer: + int leftValueInt = leftAccess.getInt(leftIter, 0); + int rightValueInt = rightAccess.getInt(rightIter, 0); + switch (resultType) { + case Logical: + return function.applyLogical(leftValueInt, rightValueInt); + case Integer: + return function.applyInteger(leftValueInt, rightValueInt); + case Double: + return function.applyDouble(leftValueInt, rightValueInt); + default: + throw RInternalError.shouldNotReachHere(); + } + case Double: + double leftValueDouble = leftAccess.getDouble(leftIter, 0); + double rightValueDouble = rightAccess.getDouble(rightIter, 0); + switch (resultType) { + case Logical: + return function.applyLogical(leftValueDouble, rightValueDouble); + case Double: + return function.applyDouble(leftValueDouble, rightValueDouble); + default: + throw RInternalError.shouldNotReachHere(); + } + case Complex: + RComplex leftValueComplex = leftAccess.getComplex(leftIter, 0); + RComplex rightValueComplex = rightAccess.getComplex(rightIter, 0); + switch (resultType) { + case Logical: + return function.applyLogical(leftValueComplex, rightValueComplex); + case Complex: + return function.applyComplex(leftValueComplex, rightValueComplex); + default: + throw RInternalError.shouldNotReachHere(); + } + case Character: + String leftValueString = leftAccess.getString(leftIter, 0); + String rightValueString = rightAccess.getString(rightIter, 0); + switch (resultType) { + case Logical: + return function.applyLogical(leftValueString, rightValueString); + default: + throw RInternalError.shouldNotReachHere(); + } + default: + throw RInternalError.shouldNotReachHere(); + } + } + } +} + +final class BinaryMapVectorNode extends BinaryMapNode { @Child private VectorMapBinaryInternalNode vectorNode; - @Child private BinaryMapFunctionNode function; @Child private CopyAttributesNode copyAttributes; @Child private GetDimAttributeNode getLeftDimNode = GetDimAttributeNode.create(); @Child private GetDimAttributeNode getRightDimNode = GetDimAttributeNode.create(); @Child private HasFixedAttributeNode hasLeftDimNode = HasFixedAttributeNode.createDim(); @Child private HasFixedAttributeNode hasRightDimNode = HasFixedAttributeNode.createDim(); + @Child private VectorAccess fastLeftAccess; + @Child private VectorAccess fastRightAccess; + @Child private VectorAccess resultAccess; + // profiles - private final Class<? extends RAbstractVector> leftClass; - private final Class<? extends RAbstractVector> rightClass; - private final VectorLengthProfile leftLengthProfile = VectorLengthProfile.create(); - private final VectorLengthProfile rightLengthProfile = VectorLengthProfile.create(); + private final VectorLengthProfile leftLengthProfile; + private final VectorLengthProfile rightLengthProfile; private final ConditionProfile dimensionsProfile; private final ConditionProfile maxLengthProfile; - private final ConditionProfile leftIsNAProfile = ConditionProfile.createBinaryProfile(); - private final ConditionProfile rightIsNAProfile = ConditionProfile.createBinaryProfile(); - private final ConditionProfile seenEmpty = ConditionProfile.createBinaryProfile(); + private final ConditionProfile seenEmpty; private final ConditionProfile shareLeft; private final ConditionProfile shareRight; - private final RType argumentType; - private final RType resultType; + private final ConditionProfile leftIsNAProfile; + private final ConditionProfile rightIsNAProfile; // compile-time optimization flags - private final boolean scalarTypes; private final boolean mayContainMetadata; private final boolean mayFoldConstantTime; private final boolean mayShareLeft; private final boolean mayShareRight; - - private BinaryMapNode(BinaryMapFunctionNode function, RAbstractVector left, RAbstractVector right, RType argumentType, RType resultType, boolean copyAttributes) { - this.function = function; - this.leftClass = left.getClass(); - this.rightClass = right.getClass(); + private final boolean isGeneric; + + BinaryMapVectorNode(BinaryMapFunctionNode function, RAbstractVector left, RAbstractVector right, RType argumentType, RType resultType, boolean copyAttributes, boolean isGeneric) { + super(function, left, right, argumentType, resultType); + this.leftLengthProfile = VectorLengthProfile.create(); + this.rightLengthProfile = VectorLengthProfile.create(); + this.seenEmpty = ConditionProfile.createBinaryProfile(); + this.fastLeftAccess = isGeneric ? null : left.access(); + this.fastRightAccess = isGeneric ? null : right.access(); this.vectorNode = VectorMapBinaryInternalNode.create(resultType, argumentType); - this.scalarTypes = left instanceof RScalarVector && right instanceof RScalarVector; boolean leftVectorImpl = RVector.class.isAssignableFrom(leftClass); boolean rightVectorImpl = RVector.class.isAssignableFrom(rightClass); this.mayContainMetadata = leftVectorImpl || rightVectorImpl; this.mayFoldConstantTime = function.mayFoldConstantTime(leftClass, rightClass); + this.leftIsNAProfile = mayFoldConstantTime ? ConditionProfile.createBinaryProfile() : null; + this.rightIsNAProfile = mayFoldConstantTime ? ConditionProfile.createBinaryProfile() : null; this.mayShareLeft = left.getRType() == resultType && leftVectorImpl; this.mayShareRight = right.getRType() == resultType && rightVectorImpl; - this.argumentType = argumentType; - this.resultType = resultType; - this.maxLengthProfile = ConditionProfile.createBinaryProfile(); - // lazily create profiles only if needed to avoid unnecessary allocations this.shareLeft = mayShareLeft ? ConditionProfile.createBinaryProfile() : null; this.shareRight = mayShareRight ? ConditionProfile.createBinaryProfile() : null; this.dimensionsProfile = mayContainMetadata ? ConditionProfile.createBinaryProfile() : null; this.copyAttributes = mayContainMetadata ? CopyAttributesNodeGen.create(copyAttributes) : null; + this.maxLengthProfile = ConditionProfile.createBinaryProfile(); + this.isGeneric = isGeneric; } - public static BinaryMapNode create(BinaryMapFunctionNode function, RAbstractVector left, RAbstractVector right, RType argumentType, RType resultType, boolean copyAttributes) { - return new BinaryMapNode(function, left, right, argumentType, resultType, copyAttributes); - } - - public boolean isSupported(Object left, Object right) { - return left.getClass() == leftClass && right.getClass() == rightClass; + @Override + public boolean isSupported(RAbstractVector left, RAbstractVector right) { + return left.getClass() == leftClass && right.getClass() == rightClass && (isGeneric || fastLeftAccess.supports(left) && fastRightAccess.supports(right)); } - public Object apply(Object originalLeft, Object originalRight) { + @Override + public Object apply(RAbstractVector originalLeft, RAbstractVector originalRight) { assert isSupported(originalLeft, originalRight); RAbstractVector left = leftClass.cast(originalLeft); RAbstractVector right = rightClass.cast(originalRight); - RAbstractVector leftCast = left.castSafe(argumentType, leftIsNAProfile, false); - RAbstractVector rightCast = right.castSafe(argumentType, rightIsNAProfile, false); + function.enable(left, right); - assert leftCast != null; - assert rightCast != null; - - function.enable(leftCast, rightCast); + if (mayContainMetadata && (dimensionsProfile.profile(hasLeftDimNode.execute(left) && hasRightDimNode.execute(right)))) { + if (differentDimensions(left, right)) { + throw error(RError.Message.NON_CONFORMABLE_ARRAYS); + } + } - if (scalarTypes) { - assert left.getLength() == 1; - assert right.getLength() == 1; - return applyScalar(leftCast, rightCast); - } else { - int leftLength = leftLengthProfile.profile(left.getLength()); - int rightLength = rightLengthProfile.profile(right.getLength()); - return applyVectorized(left, leftCast, leftLength, right, rightCast, rightLength); + VectorAccess leftAccess = isGeneric ? left.slowPathAccess() : fastLeftAccess; + VectorAccess rightAccess = isGeneric ? right.slowPathAccess() : fastRightAccess; + try (SequentialIterator leftIter = leftAccess.access(left); + SequentialIterator rightIter = rightAccess.access(right)) { + RAbstractVector target = null; + int leftLength = leftLengthProfile.profile(leftAccess.getLength(leftIter)); + int rightLength = rightLengthProfile.profile(rightAccess.getLength(rightIter)); + if (seenEmpty.profile(leftLength == 0 || rightLength == 0)) { + /* + * It is safe to skip attribute handling here as they are never copied if length is + * 0 of either side. Note that dimension check still needs to be performed. + */ + return resultType.getEmpty(); + } + if (mayFoldConstantTime) { + target = function.tryFoldConstantTime(left.castSafe(argumentType, leftIsNAProfile, false), leftLength, right.castSafe(argumentType, rightIsNAProfile, false), rightLength); + } + if (target == null) { + int maxLength = maxLengthProfile.profile(leftLength >= rightLength) ? leftLength : rightLength; + + assert left.getLength() == leftLength; + assert right.getLength() == rightLength; + if (mayShareLeft && left.getRType() == resultType && shareLeft.profile(leftLength == maxLength && ((RShareable) left).isTemporary())) { + target = left; + vectorNode.execute(function, leftLength, rightLength, leftAccess, leftIter, leftAccess, leftIter, rightAccess, rightIter); + } else if (mayShareRight && right.getRType() == resultType && shareRight.profile(rightLength == maxLength && ((RShareable) right).isTemporary())) { + target = right; + vectorNode.execute(function, leftLength, rightLength, rightAccess, rightIter, leftAccess, leftIter, rightAccess, rightIter); + } else { + if (resultAccess == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + resultAccess = insert(VectorAccess.createNew(resultType)); + } + target = resultType.create(maxLength, false); + try (SequentialIterator resultIter = resultAccess.access(target)) { + vectorNode.execute(function, leftLength, rightLength, resultAccess, resultIter, leftAccess, leftIter, rightAccess, rightIter); + } + } + RBaseNode.reportWork(this, maxLength); + target.setComplete(function.isComplete()); + } + if (mayContainMetadata) { + target = copyAttributes.execute(target, left, leftLength, right, rightLength); + } + return target; } } @@ -166,353 +284,308 @@ public final class BinaryMapNode extends RBaseNode { } return false; } +} + +abstract class VectorMapBinaryInternalNode extends RBaseNode { + + private abstract static class MapBinaryIndexedAction { + public abstract void perform(BinaryMapFunctionNode action, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter); + } + + private static final MapBinaryIndexedAction LOGICAL_LOGICAL = new MapBinaryIndexedAction() { + @Override + public void perform(BinaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter) { + result.setLogical(resultIter, arithmetic.applyLogical(left.getLogical(leftIter), right.getLogical(rightIter))); + } + }; + private static final MapBinaryIndexedAction LOGICAL_INTEGER = new MapBinaryIndexedAction() { + @Override + public void perform(BinaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter) { + result.setLogical(resultIter, arithmetic.applyLogical(left.getInt(leftIter), right.getInt(rightIter))); + } + }; + private static final MapBinaryIndexedAction LOGICAL_DOUBLE = new MapBinaryIndexedAction() { + @Override + public void perform(BinaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter) { + result.setLogical(resultIter, arithmetic.applyLogical(left.getDouble(leftIter), right.getDouble(rightIter))); + } + }; + private static final MapBinaryIndexedAction LOGICAL_COMPLEX = new MapBinaryIndexedAction() { + @Override + public void perform(BinaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter) { + result.setLogical(resultIter, arithmetic.applyLogical(left.getComplex(leftIter), right.getComplex(rightIter))); + } + }; + private static final MapBinaryIndexedAction LOGICAL_CHARACTER = new MapBinaryIndexedAction() { + @Override + public void perform(BinaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter) { + result.setLogical(resultIter, arithmetic.applyLogical(left.getString(leftIter), right.getString(rightIter))); + } + }; + private static final MapBinaryIndexedAction LOGICAL_RAW = new MapBinaryIndexedAction() { + @Override + public void perform(BinaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter) { + result.setLogical(resultIter, arithmetic.applyLogical(RRuntime.raw2int(left.getRaw(leftIter)), RRuntime.raw2int(right.getRaw(rightIter)))); + } + }; + private static final MapBinaryIndexedAction RAW_RAW = new MapBinaryIndexedAction() { + @Override + public void perform(BinaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter) { + result.setRaw(resultIter, arithmetic.applyRaw(left.getRaw(leftIter), right.getRaw(rightIter))); + } + }; + private static final MapBinaryIndexedAction INTEGER_INTEGER = new MapBinaryIndexedAction() { + @Override + public void perform(BinaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter) { + result.setInt(resultIter, arithmetic.applyInteger(left.getInt(leftIter), right.getInt(rightIter))); + } + }; + private static final MapBinaryIndexedAction DOUBLE_INTEGER = new MapBinaryIndexedAction() { + @Override + public void perform(BinaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter) { + result.setDouble(resultIter, arithmetic.applyDouble(left.getInt(leftIter), right.getInt(rightIter))); + } + }; + private static final MapBinaryIndexedAction DOUBLE = new MapBinaryIndexedAction() { + @Override + public void perform(BinaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter) { + result.setDouble(resultIter, arithmetic.applyDouble(left.getDouble(leftIter), right.getDouble(rightIter))); + } + }; + private static final MapBinaryIndexedAction COMPLEX = new MapBinaryIndexedAction() { + @Override + public void perform(BinaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter) { + RComplex value = arithmetic.applyComplex(left.getComplex(leftIter), right.getComplex(rightIter)); + result.setComplex(resultIter, value.getRealPart(), value.getImaginaryPart()); + } + }; + private static final MapBinaryIndexedAction CHARACTER = new MapBinaryIndexedAction() { + @Override + public void perform(BinaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter) { + result.setString(resultIter, arithmetic.applyCharacter(left.getString(leftIter), right.getString(rightIter))); + } + }; + + private final MapBinaryIndexedAction indexedAction; + + protected VectorMapBinaryInternalNode(RType resultType, RType argumentType) { + this.indexedAction = createIndexedAction(resultType, argumentType); + } - private Object applyScalar(RAbstractVector left, RAbstractVector right) { - switch (argumentType) { + public static VectorMapBinaryInternalNode create(RType resultType, RType argumentType) { + return VectorMapBinaryInternalNodeGen.create(resultType, argumentType); + } + + private static MapBinaryIndexedAction createIndexedAction(RType resultType, RType argumentType) { + switch (resultType) { case Raw: - byte leftValueRaw = ((RAbstractRawVector) left).getRawDataAt(0); - byte rightValueRaw = ((RAbstractRawVector) right).getRawDataAt(0); - switch (resultType) { - case Raw: - return RRaw.valueOf(function.applyRaw(leftValueRaw, rightValueRaw)); - case Logical: - return function.applyLogical(RRuntime.raw2int(leftValueRaw), RRuntime.raw2int(rightValueRaw)); - default: - throw RInternalError.shouldNotReachHere(); - } + assert argumentType == RType.Raw; + return RAW_RAW; case Logical: - byte leftValueLogical = ((RAbstractLogicalVector) left).getDataAt(0); - byte rightValueLogical = ((RAbstractLogicalVector) right).getDataAt(0); - return function.applyLogical(leftValueLogical, rightValueLogical); - case Integer: - int leftValueInt = ((RAbstractIntVector) left).getDataAt(0); - int rightValueInt = ((RAbstractIntVector) right).getDataAt(0); - switch (resultType) { + switch (argumentType) { + case Raw: + return LOGICAL_RAW; case Logical: - return function.applyLogical(leftValueInt, rightValueInt); + return LOGICAL_LOGICAL; case Integer: - return function.applyInteger(leftValueInt, rightValueInt); + return LOGICAL_INTEGER; case Double: - return function.applyDouble(leftValueInt, rightValueInt); + return LOGICAL_DOUBLE; + case Complex: + return LOGICAL_COMPLEX; + case Character: + return LOGICAL_CHARACTER; default: throw RInternalError.shouldNotReachHere(); } + case Integer: + assert argumentType == RType.Integer; + return INTEGER_INTEGER; case Double: - double leftValueDouble = ((RAbstractDoubleVector) left).getDataAt(0); - double rightValueDouble = ((RAbstractDoubleVector) right).getDataAt(0); - switch (resultType) { - case Logical: - return function.applyLogical(leftValueDouble, rightValueDouble); + switch (argumentType) { + case Integer: + return DOUBLE_INTEGER; case Double: - return function.applyDouble(leftValueDouble, rightValueDouble); + return DOUBLE; default: throw RInternalError.shouldNotReachHere(); } case Complex: - RComplex leftValueComplex = ((RAbstractComplexVector) left).getDataAt(0); - RComplex rightValueComplex = ((RAbstractComplexVector) right).getDataAt(0); - switch (resultType) { - case Logical: - return function.applyLogical(leftValueComplex, rightValueComplex); - case Complex: - return function.applyComplex(leftValueComplex, rightValueComplex); - default: - throw RInternalError.shouldNotReachHere(); - } + assert argumentType == RType.Complex; + return COMPLEX; case Character: - String leftValueString = ((RAbstractStringVector) left).getDataAt(0); - String rightValueString = ((RAbstractStringVector) right).getDataAt(0); - switch (resultType) { - case Logical: - return function.applyLogical(leftValueString, rightValueString); - default: - throw RInternalError.shouldNotReachHere(); - } + assert argumentType == RType.Character; + return CHARACTER; default: throw RInternalError.shouldNotReachHere(); } } - private Object applyVectorized(RAbstractVector left, RAbstractVector leftCast, int leftLength, RAbstractVector right, RAbstractVector rightCast, int rightLength) { - if (mayContainMetadata && (dimensionsProfile.profile(hasLeftDimNode.execute(left) && hasRightDimNode.execute(right)))) { - if (differentDimensions(left, right)) { - throw error(RError.Message.NON_CONFORMABLE_ARRAYS); - } - } - - if (seenEmpty.profile(leftLength == 0 || rightLength == 0)) { - /* - * It is safe to skip attribute handling here as they are never copied if length is 0 of - * either side. Note that dimension check still needs to be performed. - */ - return resultType.getEmpty(); - } + public abstract void execute(BinaryMapFunctionNode node, int leftLength, int rightLength, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter); - RAbstractVector target = null; - if (mayFoldConstantTime) { - target = function.tryFoldConstantTime(leftCast, leftLength, rightCast, rightLength); - } - if (target == null) { - int maxLength = maxLengthProfile.profile(leftLength >= rightLength) ? leftLength : rightLength; - RVector<?> targetVec = createOrShareVector(leftLength, left, rightLength, right, maxLength); - target = targetVec; - - assert left.getLength() == leftLength; - assert right.getLength() == rightLength; - assert leftCast.getRType() == argumentType; - assert rightCast.getRType() == argumentType; - - vectorNode.execute(function, targetVec, leftCast, leftLength, rightCast, rightLength); - RBaseNode.reportWork(this, maxLength); - target.setComplete(function.isComplete()); + @Specialization(guards = {"leftLength == 1", "rightLength == 1"}) + protected void doScalarScalar(BinaryMapFunctionNode node, @SuppressWarnings("unused") int leftLength, @SuppressWarnings("unused") int rightLength, VectorAccess result, + SequentialIterator resultIter, VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter) { + left.next(leftIter); + right.next(rightIter); + if (result != right && result != left) { + result.next(resultIter); } - if (mayContainMetadata) { - target = copyAttributes.execute(target, left, leftLength, right, rightLength); - } - return target; + indexedAction.perform(node, result, resultIter, left, leftIter, right, rightIter); } - private RVector<?> createOrShareVector(int leftLength, RAbstractVector left, int rightLength, RAbstractVector right, int maxLength) { - if (mayShareLeft && left.getRType() == resultType && shareLeft.profile(leftLength == maxLength && ((RShareable) left).isTemporary()) && left instanceof RVector<?>) { - return (RVector<?>) left; - } - if (mayShareRight && right.getRType() == resultType && shareRight.profile(rightLength == maxLength && ((RShareable) right).isTemporary()) && right instanceof RVector<?>) { - return (RVector<?>) right; + @Specialization(replaces = "doScalarScalar", guards = {"leftLength == 1"}) + protected void doScalarVector(BinaryMapFunctionNode node, @SuppressWarnings("unused") int leftLength, int rightLength, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter, + @Cached("createCountingProfile()") LoopConditionProfile profile) { + profile.profileCounted(rightLength); + left.next(leftIter); + while (profile.inject(right.next(rightIter))) { + if (result != right && result != left) { + result.next(resultIter); + } + indexedAction.perform(node, result, resultIter, left, leftIter, right, rightIter); } - return resultType.create(maxLength, false); } - @ImportStatic(Utils.class) - protected abstract static class VectorMapBinaryInternalNode extends RBaseNode { - - private static final MapBinaryIndexedAction<Byte> LOGICAL_LOGICAL = // - (arithmetic, leftVal, rightVal) -> arithmetic.applyLogical(leftVal, rightVal); - private static final MapBinaryIndexedAction<Integer> LOGICAL_INTEGER = // - (arithmetic, leftVal, rightVal) -> arithmetic.applyLogical(leftVal, rightVal); - private static final MapBinaryIndexedAction<Double> LOGICAL_DOUBLE = // - (arithmetic, leftVal, rightVal) -> arithmetic.applyLogical(leftVal, rightVal); - private static final MapBinaryIndexedAction<RComplex> LOGICAL_COMPLEX = // - (arithmetic, leftVal, rightVal) -> arithmetic.applyLogical(leftVal, rightVal); - private static final MapBinaryIndexedAction<String> LOGICAL_CHARACTER = // - (arithmetic, leftVal, rightVal) -> arithmetic.applyLogical(leftVal, rightVal); - private static final MapBinaryIndexedAction<Byte> LOGICAL_RAW = // - (arithmetic, leftVal, rightVal) -> arithmetic.applyLogical(RRuntime.raw2int(leftVal), RRuntime.raw2int(rightVal)); - private static final MapBinaryIndexedAction<Byte> RAW_RAW = // - (arithmetic, leftVal, rightVal) -> arithmetic.applyRaw(leftVal, rightVal); - private static final MapBinaryIndexedAction<Integer> INTEGER_INTEGER = // - (arithmetic, leftVal, rightVal) -> arithmetic.applyInteger(leftVal, rightVal); - private static final MapBinaryIndexedAction<Integer> DOUBLE_INTEGER = // - (arithmetic, leftVal, rightVal) -> arithmetic.applyDouble(leftVal, rightVal); - private static final MapBinaryIndexedAction<Double> DOUBLE = // - (arithmetic, leftVal, rightVal) -> arithmetic.applyDouble(leftVal, rightVal); - private static final MapBinaryIndexedAction<RComplex> COMPLEX = // - (arithmetic, leftVal, rightVal) -> arithmetic.applyComplex(leftVal, rightVal); - private static final MapBinaryIndexedAction<String> CHARACTER = // - (arithmetic, leftVal, rightVal) -> arithmetic.applyCharacter(leftVal, rightVal); - - private final MapBinaryIndexedAction<Object> indexedAction; - - @Child private GetDataStore getTargetDataStore = GetDataStore.create(); - @Child private SetDataAt targetSetDataAt; - - @SuppressWarnings("unchecked") - protected VectorMapBinaryInternalNode(RType resultType, RType argumentType) { - this.indexedAction = (MapBinaryIndexedAction<Object>) createIndexedAction(resultType, argumentType); - this.targetSetDataAt = Utils.createSetDataAtNode(resultType); + @Specialization(replaces = "doScalarScalar", guards = {"rightLength == 1"}) + protected void doVectorScalar(BinaryMapFunctionNode node, int leftLength, @SuppressWarnings("unused") int rightLength, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter, + @Cached("createCountingProfile()") LoopConditionProfile profile) { + profile.profileCounted(leftLength); + right.next(rightIter); + while (profile.inject(left.next(leftIter))) { + if (result != left && result != right) { + result.next(resultIter); + } + indexedAction.perform(node, result, resultIter, left, leftIter, right, rightIter); } + } - public static VectorMapBinaryInternalNode create(RType resultType, RType argumentType) { - return VectorMapBinaryInternalNodeGen.create(resultType, argumentType); + @Specialization(guards = {"leftLength == rightLength"}) + protected void doSameLength(BinaryMapFunctionNode node, int leftLength, @SuppressWarnings("unused") int rightLength, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter, + @Cached("createCountingProfile()") LoopConditionProfile profile) { + profile.profileCounted(leftLength); + while (profile.inject(left.next(leftIter))) { + right.next(rightIter); + if (result != left && result != right) { + result.next(resultIter); + } + indexedAction.perform(node, result, resultIter, left, leftIter, right, rightIter); } + } - private static MapBinaryIndexedAction<?> createIndexedAction(RType resultType, RType argumentType) { - switch (resultType) { - case Raw: - assert argumentType == RType.Raw; - return RAW_RAW; - case Logical: - switch (argumentType) { - case Raw: - return LOGICAL_RAW; - case Logical: - return LOGICAL_LOGICAL; - case Integer: - return LOGICAL_INTEGER; - case Double: - return LOGICAL_DOUBLE; - case Complex: - return LOGICAL_COMPLEX; - case Character: - return LOGICAL_CHARACTER; - default: - throw RInternalError.shouldNotReachHere(); + @Specialization(replaces = {"doVectorScalar", "doScalarVector", "doSameLength"}, guards = {"leftLength >= rightLength"}) + protected void doMultiplesLeft(BinaryMapFunctionNode node, int leftLength, int rightLength, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter, + @Cached("createCountingProfile()") LoopConditionProfile leftProfile, + @Cached("createCountingProfile()") LoopConditionProfile rightProfile, + @Cached("createBinaryProfile()") ConditionProfile smallRemainderProfile) { + assert result != right; + leftProfile.profileCounted(leftLength); + rightProfile.profileCounted(rightLength); + while (leftProfile.inject(leftIter.getIndex() + 1 < leftLength)) { + right.reset(rightIter); + if (smallRemainderProfile.profile((leftLength - leftIter.getIndex() - 1) >= rightLength)) { + // we need at least rightLength more elements + while (rightProfile.inject(right.next(rightIter)) && leftProfile.inject(left.next(leftIter))) { + if (result != left) { + result.next(resultIter); } - case Integer: - assert argumentType == RType.Integer; - return INTEGER_INTEGER; - case Double: - switch (argumentType) { - case Integer: - return DOUBLE_INTEGER; - case Double: - return DOUBLE; - default: - throw RInternalError.shouldNotReachHere(); + indexedAction.perform(node, result, resultIter, left, leftIter, right, rightIter); + } + } else { + while (rightProfile.inject(right.next(rightIter)) && leftProfile.inject(left.next(leftIter))) { + if (result != left) { + result.next(resultIter); } - case Complex: - assert argumentType == RType.Complex; - return COMPLEX; - case Character: - assert argumentType == RType.Character; - return CHARACTER; - default: - throw RInternalError.shouldNotReachHere(); - } - } - - public abstract void execute(BinaryMapFunctionNode node, RVector<?> store, RAbstractVector left, int leftLength, RAbstractVector right, int rightLength); - - @Specialization(guards = {"leftLength == 1", "rightLength == 1"}) - @SuppressWarnings("unused") - protected void doScalarScalar(BinaryMapFunctionNode node, RVector<?> result, RAbstractVector left, int leftLength, RAbstractVector right, int rightLength, - @Cached("createIterator()") VectorIterator.Generic leftIterator, - @Cached("createIterator()") VectorIterator.Generic rightIterator) { - Object itLeft = leftIterator.init(left); - Object itRight = rightIterator.init(right); - Object value = indexedAction.perform(node, leftIterator.next(left, itLeft), rightIterator.next(right, itRight)); - targetSetDataAt.setDataAtAsObject(result, getTargetDataStore.execute(result), 0, value); - } - - @Specialization(replaces = "doScalarScalar", guards = {"leftLength == 1"}) - @SuppressWarnings("unused") - protected void doScalarVector(BinaryMapFunctionNode node, RVector<?> result, RAbstractVector left, int leftLength, RAbstractVector right, int rightLength, - @Cached("createIterator()") VectorIterator.Generic leftIterator, - @Cached("createIterator()") VectorIterator.Generic rightIterator, - @Cached("createCountingProfile()") LoopConditionProfile profile) { - profile.profileCounted(rightLength); - Object itLeft = leftIterator.init(left); - Object itRight = rightIterator.init(right); - Object resultStore = getTargetDataStore.execute(result); - Object leftValue = leftIterator.next(left, itLeft); - for (int i = 0; profile.inject(i < rightLength); ++i) { - Object value = indexedAction.perform(node, leftValue, rightIterator.next(right, itRight)); - targetSetDataAt.setDataAtAsObject(result, resultStore, i, value); + indexedAction.perform(node, result, resultIter, left, leftIter, right, rightIter); + } + RError.warning(this, RError.Message.LENGTH_NOT_MULTI); } } + } - @Specialization(replaces = "doScalarScalar", guards = {"rightLength == 1"}) - @SuppressWarnings("unused") - protected void doVectorScalar(BinaryMapFunctionNode node, RVector<?> result, RAbstractVector left, int leftLength, RAbstractVector right, int rightLength, - @Cached("createIterator()") VectorIterator.Generic leftIterator, - @Cached("createIterator()") VectorIterator.Generic rightIterator, - @Cached("createCountingProfile()") LoopConditionProfile profile) { - profile.profileCounted(leftLength); - Object itLeft = leftIterator.init(left); - Object itRight = rightIterator.init(right); - Object resultStore = getTargetDataStore.execute(result); - Object rightValue = rightIterator.next(right, itRight); - for (int i = 0; profile.inject(i < leftLength); ++i) { - Object value = indexedAction.perform(node, leftIterator.next(left, itLeft), rightValue); - targetSetDataAt.setDataAtAsObject(result, resultStore, i, value); + @Specialization(replaces = {"doVectorScalar", "doScalarVector", "doSameLength"}, guards = {"rightLength >= leftLength"}) + protected void doMultiplesRight(BinaryMapFunctionNode node, int leftLength, int rightLength, VectorAccess result, SequentialIterator resultIter, + VectorAccess left, SequentialIterator leftIter, VectorAccess right, SequentialIterator rightIter, + @Cached("createCountingProfile()") LoopConditionProfile leftProfile, + @Cached("createCountingProfile()") LoopConditionProfile rightProfile, + @Cached("createBinaryProfile()") ConditionProfile smallRemainderProfile) { + assert result != left; + leftProfile.profileCounted(leftLength); + rightProfile.profileCounted(rightLength); + while (rightProfile.inject(rightIter.getIndex() + 1 < rightLength)) { + left.reset(leftIter); + if (smallRemainderProfile.profile((rightLength - rightIter.getIndex() - 1) >= leftLength)) { + // we need at least leftLength more elements + while (leftProfile.inject(left.next(leftIter)) && rightProfile.inject(right.next(rightIter))) { + if (result != right) { + result.next(resultIter); + } + indexedAction.perform(node, result, resultIter, left, leftIter, right, rightIter); + } + } else { + while (leftProfile.inject(left.next(leftIter)) && rightProfile.inject(right.next(rightIter))) { + if (result != right) { + result.next(resultIter); + } + indexedAction.perform(node, result, resultIter, left, leftIter, right, rightIter); + } + RError.warning(this, RError.Message.LENGTH_NOT_MULTI); } } + } +} - @Specialization(guards = {"leftLength == rightLength"}) - @SuppressWarnings("unused") - protected void doSameLength(BinaryMapFunctionNode node, RVector<?> result, RAbstractVector left, int leftLength, RAbstractVector right, int rightLength, - @Cached("createIterator()") VectorIterator.Generic leftIterator, - @Cached("createIterator()") VectorIterator.Generic rightIterator, - @Cached("createCountingProfile()") LoopConditionProfile profile) { - profile.profileCounted(leftLength); - Object itLeft = leftIterator.init(left); - Object itRight = rightIterator.init(right); - Object resultStore = getTargetDataStore.execute(result); - for (int i = 0; profile.inject(i < leftLength); ++i) { - Object value = indexedAction.perform(node, leftIterator.next(left, itLeft), rightIterator.next(right, itRight)); - targetSetDataAt.setDataAtAsObject(result, resultStore, i, value); - } - } +/** + * Implements a binary map operation that maps two vectors into a single result vector of the + * maximum size of both vectors. Vectors with smaller length are repeated. The actual implementation + * is provided using a {@link BinaryMapFunctionNode}. + * + * The implementation tries to share input vectors if they are implementing {@link RShareable}. + */ +public abstract class BinaryMapNode extends RBaseNode { - protected static boolean multiplesMinMax(int min, int max) { - return max % min == 0; - } + @Child protected BinaryMapFunctionNode function; + protected final Class<? extends RAbstractVector> leftClass; + protected final Class<? extends RAbstractVector> rightClass; + protected final RType argumentType; + protected final RType resultType; - @Specialization(replaces = {"doVectorScalar", "doScalarVector", "doSameLength"}, guards = {"multiplesMinMax(leftLength, rightLength)"}) - protected void doMultiplesLeft(BinaryMapFunctionNode node, RVector<?> result, RAbstractVector left, int leftLength, RAbstractVector right, int rightLength, - @Cached("createIteratorWrapAround()") VectorIterator.Generic leftIterator, - @Cached("createIterator()") VectorIterator.Generic rightIterator, - @Cached("createCountingProfile()") LoopConditionProfile leftProfile, - @Cached("createCountingProfile()") LoopConditionProfile rightProfile) { - int j = 0; - rightProfile.profileCounted(rightLength / leftLength); - Object itLeft = leftIterator.init(left); - Object itRight = rightIterator.init(right); - Object resultStore = getTargetDataStore.execute(result); - while (rightProfile.inject(j < rightLength)) { - leftProfile.profileCounted(leftLength); - for (int k = 0; leftProfile.inject(k < leftLength); k++) { - Object value = indexedAction.perform(node, leftIterator.next(left, itLeft), rightIterator.next(right, itRight)); - targetSetDataAt.setDataAtAsObject(result, resultStore, j, value); - j++; - } - } - } + protected BinaryMapNode(BinaryMapFunctionNode function, RAbstractVector left, RAbstractVector right, RType argumentType, RType resultType) { + this.function = function; + this.leftClass = left.getClass(); + this.rightClass = right.getClass(); + this.argumentType = argumentType; + this.resultType = resultType; + } - @Specialization(replaces = {"doVectorScalar", "doScalarVector", "doSameLength"}, guards = {"multiplesMinMax(rightLength, leftLength)"}) - protected void doMultiplesRight(BinaryMapFunctionNode node, RVector<?> result, RAbstractVector left, int leftLength, RAbstractVector right, int rightLength, - @Cached("createIterator()") VectorIterator.Generic leftIterator, - @Cached("createIteratorWrapAround()") VectorIterator.Generic rightIterator, - @Cached("createCountingProfile()") LoopConditionProfile leftProfile, - @Cached("createCountingProfile()") LoopConditionProfile rightProfile) { - int j = 0; - leftProfile.profileCounted(leftLength / rightLength); - Object itLeft = leftIterator.init(left); - Object itRight = rightIterator.init(right); - Object resultStore = getTargetDataStore.execute(result); - while (leftProfile.inject(j < leftLength)) { - rightProfile.profileCounted(rightLength); - for (int k = 0; rightProfile.inject(k < rightLength); k++) { - Object value = indexedAction.perform(node, leftIterator.next(left, itLeft), rightIterator.next(right, itRight)); - targetSetDataAt.setDataAtAsObject(result, resultStore, j, value); - j++; - } - } + public static BinaryMapNode create(BinaryMapFunctionNode function, RAbstractVector left, RAbstractVector right, RType argumentType, RType resultType, boolean copyAttributes, boolean isGeneric) { + if (left instanceof RScalarVector && right instanceof RScalarVector) { + return new BinaryMapScalarNode(function, left, right, argumentType, resultType); + } else { + return new BinaryMapVectorNode(function, left, right, argumentType, resultType, copyAttributes, isGeneric); } + } - protected static boolean multiples(int leftLength, int rightLength) { - int min; - int max; - if (leftLength >= rightLength) { - min = rightLength; - max = leftLength; - } else { - min = leftLength; - max = rightLength; - } - return max % min == 0; - } + public abstract boolean isSupported(RAbstractVector left, RAbstractVector right); - @Specialization(guards = {"!multiples(leftLength, rightLength)"}) - protected void doNoMultiples(BinaryMapFunctionNode node, RVector<?> result, RAbstractVector left, int leftLength, RAbstractVector right, int rightLength, - @Cached("createIteratorWrapAround()") VectorIterator.Generic leftIterator, - @Cached("createIteratorWrapAround()") VectorIterator.Generic rightIterator, - @Cached("createCountingProfile()") LoopConditionProfile profile, - @Cached("createBinaryProfile()") ConditionProfile leftIncModProfile, - @Cached("createBinaryProfile()") ConditionProfile rightIncModProfile) { - int max = Math.max(leftLength, rightLength); - profile.profileCounted(max); - Object itLeft = leftIterator.init(left); - Object itRight = rightIterator.init(right); - Object resultStore = getTargetDataStore.execute(result); - for (int i = 0; profile.inject(i < max); ++i) { - Object value = indexedAction.perform(node, leftIterator.next(left, itLeft), rightIterator.next(right, itRight)); - targetSetDataAt.setDataAtAsObject(result, resultStore, i, value); - } - RError.warning(this, RError.Message.LENGTH_NOT_MULTI); - } + public abstract Object apply(RAbstractVector originalLeft, RAbstractVector originalRight); - private interface MapBinaryIndexedAction<V> { - Object perform(BinaryMapFunctionNode action, V left, V right); - } - } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/primitive/UnaryMapNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/primitive/UnaryMapNode.java index aaa72efdec80a239329d21e901f04e54825e7a27..1b997142d4066f3258265d10921a4ce32d70d39f 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/primitive/UnaryMapNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/primitive/UnaryMapNode.java @@ -25,7 +25,6 @@ package com.oracle.truffle.r.nodes.primitive; import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Cached; -import com.oracle.truffle.api.dsl.ImportStatic; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.ConditionProfile; @@ -35,7 +34,6 @@ import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAt import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimNamesAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimAttributeNode; -import com.oracle.truffle.r.nodes.primitive.UnaryMapNodeFactory.MapUnaryVectorInternalNodeGen; import com.oracle.truffle.r.nodes.profile.VectorLengthProfile; import com.oracle.truffle.r.runtime.RInternalError; import com.oracle.truffle.r.runtime.RType; @@ -43,121 +41,130 @@ import com.oracle.truffle.r.runtime.data.RComplex; import com.oracle.truffle.r.runtime.data.RScalarVector; import com.oracle.truffle.r.runtime.data.RShareable; import com.oracle.truffle.r.runtime.data.RVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; -import com.oracle.truffle.r.runtime.data.nodes.GetDataStore; -import com.oracle.truffle.r.runtime.data.nodes.SetDataAt; -import com.oracle.truffle.r.runtime.data.nodes.VectorIterator; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; import com.oracle.truffle.r.runtime.nodes.RBaseNode; -public final class UnaryMapNode extends RBaseNode { +final class UnaryMapScalarNode extends UnaryMapNode { + + @Child private VectorAccess operandAccess; + + UnaryMapScalarNode(UnaryMapFunctionNode scalarNode, RAbstractVector operand, RType argumentType, RType resultType) { + super(scalarNode, operand, argumentType, resultType); + this.operandAccess = operand.access(); + } + + @Override + public boolean isSupported(RAbstractVector operand) { + return operandAccess.supports(operand); + } + + @Override + public Object apply(RAbstractVector operand) { + assert isSupported(operand); + + function.enable(operand); + assert operand.getLength() == 1; + + try (RandomIterator iter = operandAccess.randomAccess(operand)) { + switch (argumentType) { + case Logical: + return function.applyLogical(operandAccess.getLogical(iter, 0)); + case Integer: + return function.applyInteger(operandAccess.getInt(iter, 0)); + case Double: + return function.applyDouble(operandAccess.getDouble(iter, 0)); + case Complex: + switch (resultType) { + case Double: + return function.applyDouble(operandAccess.getComplex(iter, 0)); + case Complex: + return function.applyComplex(operandAccess.getComplex(iter, 0)); + default: + throw RInternalError.shouldNotReachHere(); + } + default: + throw RInternalError.shouldNotReachHere(); + } + } + } +} + +final class UnaryMapVectorNode extends UnaryMapNode { - @Child private UnaryMapFunctionNode scalarNode; @Child private MapUnaryVectorInternalNode vectorNode; @Child private GetDimAttributeNode getDimNode; @Child private SetDimAttributeNode setDimNode; @Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create(); + @Child private VectorAccess fastOperandAccess; + @Child private VectorAccess resultAccess; // profiles - private final Class<? extends RAbstractVector> operandClass; private final VectorLengthProfile operandLengthProfile = VectorLengthProfile.create(); - private final ConditionProfile operandIsNAProfile = ConditionProfile.createBinaryProfile(); private final BranchProfile hasAttributesProfile; private final ConditionProfile shareOperand; // compile-time optimization flags - private final boolean scalarType; private final boolean mayContainMetadata; private final boolean mayFoldConstantTime; private final boolean mayShareOperand; + private final boolean isGeneric; - private UnaryMapNode(UnaryMapFunctionNode scalarNode, RAbstractVector operand, RType argumentType, RType resultType) { - this.scalarNode = scalarNode; + UnaryMapVectorNode(UnaryMapFunctionNode scalarNode, RAbstractVector operand, RType argumentType, RType resultType, boolean isGeneric) { + super(scalarNode, operand, argumentType, resultType); + this.fastOperandAccess = isGeneric ? null : operand.access(); this.vectorNode = MapUnaryVectorInternalNode.create(resultType, argumentType); - this.operandClass = operand.getClass(); - this.scalarType = operand instanceof RScalarVector; boolean operandVector = operand instanceof RVector; this.mayContainMetadata = operandVector; - this.mayFoldConstantTime = scalarNode.mayFoldConstantTime(operandClass); + this.mayFoldConstantTime = argumentType == operand.getRType() && scalarNode.mayFoldConstantTime(operandClass); this.mayShareOperand = operandVector; + this.isGeneric = isGeneric; // lazily create profiles only if needed to avoid unnecessary allocations - this.shareOperand = operandVector ? ConditionProfile.createBinaryProfile() : null; + this.shareOperand = mayShareOperand ? ConditionProfile.createBinaryProfile() : null; this.hasAttributesProfile = mayContainMetadata ? BranchProfile.create() : null; - } - - public static UnaryMapNode create(UnaryMapFunctionNode scalarNode, RAbstractVector operand, RType argumentType, RType resultType) { - return new UnaryMapNode(scalarNode, operand, argumentType, resultType); - } - public Class<? extends RAbstractVector> getOperandClass() { - return operandClass; } - public RType getArgumentType() { - return vectorNode.getArgumentType(); + @Override + public boolean isSupported(RAbstractVector operand) { + return operand.getClass() == operandClass && (isGeneric || fastOperandAccess.supports(operand)); } - public RType getResultType() { - return vectorNode.getResultType(); - } - - public boolean isSupported(Object operand) { - return operand.getClass() == operandClass; - } - - public Object apply(Object originalOperand) { + @Override + public Object apply(RAbstractVector originalOperand) { assert isSupported(originalOperand); RAbstractVector operand = operandClass.cast(originalOperand); + function.enable(operand); + int operandLength = operandLengthProfile.profile(operand.getLength()); - RAbstractVector operandCast = operand.castSafe(getArgumentType(), operandIsNAProfile); - - scalarNode.enable(operandCast); - if (scalarType) { - assert operand.getLength() == 1; - return scalarOperation(operandCast); - } else { - int operandLength = operandLengthProfile.profile(operand.getLength()); - return vectorOperation(operand, operandCast, operandLength); - } - } - - private Object scalarOperation(RAbstractVector operand) { - switch (getArgumentType()) { - case Logical: - return scalarNode.applyLogical(((RAbstractLogicalVector) operand).getDataAt(0)); - case Integer: - return scalarNode.applyInteger(((RAbstractIntVector) operand).getDataAt(0)); - case Double: - return scalarNode.applyDouble(((RAbstractDoubleVector) operand).getDataAt(0)); - case Complex: - switch (getResultType()) { - case Double: - return scalarNode.applyDouble(((RAbstractComplexVector) operand).getDataAt(0)); - case Complex: - return scalarNode.applyComplex(((RAbstractComplexVector) operand).getDataAt(0)); - default: - throw RInternalError.shouldNotReachHere(); - } - default: - throw RInternalError.shouldNotReachHere(); - } - } - - private Object vectorOperation(RAbstractVector operand, RAbstractVector operandCast, int operandLength) { RAbstractVector target = null; if (mayFoldConstantTime) { - target = scalarNode.tryFoldConstantTime(operandCast, operandLength); + target = function.tryFoldConstantTime(operand, operandLength); } if (target == null) { - RVector<?> targetVec = createOrShareVector(operandLength, operand); - target = targetVec; - vectorNode.apply(scalarNode, targetVec, operandCast, operandLength); + VectorAccess operandAccess = isGeneric ? operand.slowPathAccess() : fastOperandAccess; + try (SequentialIterator operandIter = operandAccess.access(operand)) { + if (mayShareOperand && operand.getRType() == resultType && shareOperand.profile(((RShareable) operand).isTemporary())) { + target = operand; + vectorNode.execute(function, operandLength, operandAccess, operandIter, operandAccess, operandIter); + } else { + if (resultAccess == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + target = resultType.create(operandLength, false); + resultAccess = insert(target.access()); + } else { + target = resultType.create(operandLength, false); + } + try (SequentialIterator resultIter = resultAccess.access(target)) { + vectorNode.execute(function, operandLength, resultAccess, resultIter, operandAccess, operandIter); + } + } + } RBaseNode.reportWork(this, operandLength); - target.setComplete(scalarNode.isComplete()); + target.setComplete(function.isComplete()); } if (mayContainMetadata) { target = handleMetadata(target, operand); @@ -165,14 +172,6 @@ public final class UnaryMapNode extends RBaseNode { return target; } - private RVector<?> createOrShareVector(int operandLength, RAbstractVector operand) { - RType resultType = getResultType(); - if (mayShareOperand && operand.getRType() == resultType && shareOperand.profile(((RShareable) operand).isTemporary()) && operand instanceof RVector<?>) { - return (RVector<?>) operand; - } - return resultType.create(operandLength, false); - } - private RAbstractVector handleMetadata(RAbstractVector target, RAbstractVector operand) { RAbstractVector result = target; if (containsMetadata(operand) && operand != target) { @@ -213,107 +212,141 @@ public final class UnaryMapNode extends RBaseNode { result.copyRegAttributesFrom(attributeSource); result.copyNamesFrom(attributeSource); } +} - @ImportStatic(Utils.class) - protected abstract static class MapUnaryVectorInternalNode extends RBaseNode { - - private static final MapIndexedAction<Byte> LOGICAL = (arithmetic, value) -> arithmetic.applyLogical(value); - private static final MapIndexedAction<Integer> INTEGER = (arithmetic, value) -> arithmetic.applyInteger(value); - private static final MapIndexedAction<Double> DOUBLE = (arithmetic, value) -> arithmetic.applyDouble(value); - private static final MapIndexedAction<RComplex> COMPLEX = (arithmetic, value) -> arithmetic.applyComplex(value); - private static final MapIndexedAction<RComplex> DOUBLE_COMPLEX = (arithmetic, value) -> arithmetic.applyDouble(value); - private static final MapIndexedAction<String> CHARACTER = (arithmetic, value) -> arithmetic.applyCharacter(value); - - private final MapIndexedAction<Object> indexedAction; - private final RType argumentType; - private final RType resultType; - - @Child private GetDataStore getTargetDataStore = GetDataStore.create(); - @Child private SetDataAt targetSetDataAt; - - @SuppressWarnings("unchecked") - protected MapUnaryVectorInternalNode(RType resultType, RType argumentType) { - this.indexedAction = (MapIndexedAction<Object>) createIndexedAction(resultType, argumentType); - this.argumentType = argumentType; - this.resultType = resultType; - this.targetSetDataAt = Utils.createSetDataAtNode(resultType); - } +abstract class MapUnaryVectorInternalNode extends RBaseNode { - public RType getArgumentType() { - return argumentType; - } + private abstract static class MapIndexedAction { + public abstract void perform(UnaryMapFunctionNode action, VectorAccess result, SequentialIterator resultIter, VectorAccess operand, SequentialIterator operandIter); + } - public RType getResultType() { - return resultType; + private static final MapIndexedAction LOGICAL = new MapIndexedAction() { + @Override + public void perform(UnaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, VectorAccess operand, SequentialIterator operandIter) { + result.setLogical(resultIter, arithmetic.applyLogical(operand.getLogical(operandIter))); } - - public static MapUnaryVectorInternalNode create(RType resultType, RType argumentType) { - return MapUnaryVectorInternalNodeGen.create(resultType, argumentType); + }; + private static final MapIndexedAction INTEGER = new MapIndexedAction() { + @Override + public void perform(UnaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, VectorAccess operand, SequentialIterator operandIter) { + result.setInt(resultIter, arithmetic.applyInteger(operand.getInt(operandIter))); } - - private static MapIndexedAction<?> createIndexedAction(RType resultType, RType argumentType) { - switch (argumentType) { - case Logical: - return LOGICAL; - case Integer: - switch (resultType) { - case Integer: - return INTEGER; - case Double: - return DOUBLE; - default: - throw RInternalError.shouldNotReachHere(); - } - case Double: - return DOUBLE; - case Complex: - switch (resultType) { - case Double: - return DOUBLE_COMPLEX; - case Complex: - return COMPLEX; - default: - throw RInternalError.shouldNotReachHere(); - } - case Character: - return CHARACTER; - default: - throw RInternalError.shouldNotReachHere(); - } + }; + private static final MapIndexedAction DOUBLE = new MapIndexedAction() { + @Override + public void perform(UnaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, VectorAccess operand, SequentialIterator operandIter) { + result.setDouble(resultIter, arithmetic.applyDouble(operand.getDouble(operandIter))); + } + }; + private static final MapIndexedAction COMPLEX = new MapIndexedAction() { + @Override + public void perform(UnaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, VectorAccess operand, SequentialIterator operandIter) { + RComplex value = arithmetic.applyComplex(operand.getComplex(operandIter)); + result.setComplex(resultIter, value.getRealPart(), value.getImaginaryPart()); + } + }; + private static final MapIndexedAction DOUBLE_COMPLEX = new MapIndexedAction() { + @Override + public void perform(UnaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, VectorAccess operand, SequentialIterator operandIter) { + result.setDouble(resultIter, arithmetic.applyDouble(operand.getComplex(operandIter))); } + }; + private static final MapIndexedAction CHARACTER = new MapIndexedAction() { + @Override + public void perform(UnaryMapFunctionNode arithmetic, VectorAccess result, SequentialIterator resultIter, VectorAccess operand, SequentialIterator operandIter) { + result.setString(resultIter, arithmetic.applyCharacter(operand.getString(operandIter))); + } + }; + + private final MapIndexedAction indexedAction; + + protected MapUnaryVectorInternalNode(RType resultType, RType argumentType) { + this.indexedAction = createIndexedAction(resultType, argumentType); + } + + public static MapUnaryVectorInternalNode create(RType resultType, RType argumentType) { + return MapUnaryVectorInternalNodeGen.create(resultType, argumentType); + } - private void apply(UnaryMapFunctionNode scalarAction, RVector<?> target, RAbstractVector operand, int operandLength) { - assert operand.getLength() == operandLength; - assert operand.getRType() == argumentType; - executeInternal(scalarAction, target, operand, operandLength); + private static MapIndexedAction createIndexedAction(RType resultType, RType argumentType) { + switch (argumentType) { + case Logical: + return LOGICAL; + case Integer: + switch (resultType) { + case Integer: + return INTEGER; + case Double: + return DOUBLE; + default: + throw RInternalError.shouldNotReachHere(); + } + case Double: + return DOUBLE; + case Complex: + switch (resultType) { + case Double: + return DOUBLE_COMPLEX; + case Complex: + return COMPLEX; + default: + throw RInternalError.shouldNotReachHere(); + } + case Character: + return CHARACTER; + default: + throw RInternalError.shouldNotReachHere(); } + } - protected abstract void executeInternal(UnaryMapFunctionNode node, Object store, RAbstractVector operand, int operandLength); + protected abstract void execute(UnaryMapFunctionNode node, int operandLength, VectorAccess result, SequentialIterator resultIter, VectorAccess operand, SequentialIterator operandIter); - @Specialization(guards = {"operandLength == 1"}) - protected void doScalar(UnaryMapFunctionNode node, RVector<?> target, RAbstractVector operand, int operandLength, - @Cached("createIterator()") VectorIterator.Generic iterator) { - Object it = iterator.init(operand); - Object targetStore = getTargetDataStore.execute(target); - Object value = iterator.next(operand, it); - targetSetDataAt.setDataAtAsObject(target, targetStore, 0, indexedAction.perform(node, value)); + @Specialization(guards = {"operandLength == 1"}) + protected void doScalar(UnaryMapFunctionNode node, @SuppressWarnings("unused") int operandLength, VectorAccess result, SequentialIterator resultIter, VectorAccess operand, + SequentialIterator operandIter) { + operand.next(operandIter); + if (result != operand) { + result.next(resultIter); } + indexedAction.perform(node, result, resultIter, operand, operandIter); + } - @Specialization(replaces = "doScalar") - protected void doScalarVector(UnaryMapFunctionNode node, RVector<?> target, RAbstractVector operand, int operandLength, - @Cached("createIterator()") VectorIterator.Generic iterator, - @Cached("createCountingProfile()") LoopConditionProfile profile) { - Object targetStore = getTargetDataStore.execute(target); - Object it = iterator.init(operand); - profile.profileCounted(operandLength); - for (int i = 0; profile.inject(i < operandLength); ++i) { - Object value = indexedAction.perform(node, iterator.next(operand, it)); - targetSetDataAt.setDataAtAsObject(target, targetStore, i, value); + @Specialization(replaces = "doScalar") + protected void doScalarVector(UnaryMapFunctionNode node, int operandLength, VectorAccess result, SequentialIterator resultIter, VectorAccess operand, SequentialIterator operandIter, + @Cached("createCountingProfile()") LoopConditionProfile profile) { + profile.profileCounted(operandLength); + while (profile.inject(operand.next(operandIter))) { + if (result != operand) { + result.next(resultIter); } + indexedAction.perform(node, result, resultIter, operand, operandIter); } + } +} + +public abstract class UnaryMapNode extends RBaseNode { - private interface MapIndexedAction<V> { - Object perform(UnaryMapFunctionNode action, V val); + @Child protected UnaryMapFunctionNode function; + protected final Class<? extends RAbstractVector> operandClass; + protected final RType argumentType; + protected final RType resultType; + + protected UnaryMapNode(UnaryMapFunctionNode function, RAbstractVector operand, RType argumentType, RType resultType) { + this.function = function; + this.operandClass = operand.getClass(); + this.argumentType = argumentType; + this.resultType = resultType; + } + + public static UnaryMapNode create(UnaryMapFunctionNode scalarNode, RAbstractVector operand, RType argumentType, RType resultType, boolean isGeneric) { + if (operand instanceof RScalarVector) { + return new UnaryMapScalarNode(scalarNode, operand, argumentType, resultType); + } else { + return new UnaryMapVectorNode(scalarNode, operand, argumentType, resultType, isGeneric); } } + + public abstract boolean isSupported(RAbstractVector operand); + + public abstract Object apply(RAbstractVector originalOperand); } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticNode.java index 1f0892d7285603e2bf2950d8ba38a5b7df910ed3..170ac15bf2741ef642c1e946b024fe507c27fd92 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/UnaryArithmeticNode.java @@ -52,42 +52,45 @@ public abstract class UnaryArithmeticNode extends UnaryNode { public abstract Object execute(Object value); @Specialization(guards = {"cachedNode != null", "cachedNode.isSupported(operand)"}) - protected Object doCached(Object operand, + protected Object doCached(RAbstractVector operand, @Cached("createCachedFast(operand)") UnaryMapNode cachedNode) { return cachedNode.apply(operand); } - protected UnaryMapNode createCachedFast(Object operand) { + protected UnaryMapNode createCachedFast(RAbstractVector operand) { if (isNumericVector(operand)) { - return createCached(unary.createOperation(), operand); + return createCached(unary.createOperation(), operand, false); } return null; } - protected static UnaryMapNode createCached(UnaryArithmetic arithmetic, Object operand) { + protected static UnaryMapNode createCached(UnaryArithmetic arithmetic, Object operand, boolean isGeneric) { if (operand instanceof RAbstractVector) { RAbstractVector castOperand = (RAbstractVector) operand; RType operandType = castOperand.getRType(); if (operandType.isNumeric()) { RType type = RType.maxPrecedence(operandType, arithmetic.getMinPrecedence()); RType resultType = arithmetic.calculateResultType(type); - return UnaryMapNode.create(new ScalarUnaryArithmeticNode(arithmetic), castOperand, type, resultType); + return UnaryMapNode.create(new ScalarUnaryArithmeticNode(arithmetic), castOperand, type, resultType, isGeneric); } } return null; } - protected static boolean isNumericVector(Object value) { + protected static boolean isNumericVector(RAbstractVector value) { return value instanceof RAbstractIntVector || value instanceof RAbstractDoubleVector || value instanceof RAbstractComplexVector || value instanceof RAbstractLogicalVector; } @Specialization(replaces = "doCached", guards = {"isNumericVector(operand)"}) @TruffleBoundary - protected Object doGeneric(Object operand, + protected Object doGeneric(RAbstractVector operand, @Cached("unary.createOperation()") UnaryArithmetic arithmetic, - @Cached("new(createCached(arithmetic, operand))") GenericNumericVectorNode generic) { - RAbstractVector operandVector = (RAbstractVector) operand; - return generic.get(arithmetic, operandVector).apply(operandVector); + @Cached("createGeneric()") GenericNumericVectorNode generic) { + return generic.get(arithmetic, operand).apply(operand); + } + + protected static GenericNumericVectorNode createGeneric() { + return new GenericNumericVectorNode(); } @Override @@ -110,16 +113,12 @@ public abstract class UnaryArithmeticNode extends UnaryNode { @Child private UnaryMapNode cached; - public GenericNumericVectorNode(UnaryMapNode cachedOperation) { - this.cached = cachedOperation; - } - public UnaryMapNode get(UnaryArithmetic arithmetic, RAbstractVector operand) { - UnaryMapNode next = cached; - if (!next.isSupported(operand)) { - next = cached.replace(createCached(arithmetic, operand)); + UnaryMapNode map = cached; + if (map == null || !map.isSupported(operand)) { + cached = map = insert(createCached(arithmetic, operand, true)); } - return next; + return map; } } } diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/BinaryArithmetic.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/BinaryArithmetic.java index 304d26eb4bd072fc48316aab7e49328df60f87a5..b7883f38067abbca3ed847c407596f6df09b8ad6 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/BinaryArithmetic.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/BinaryArithmetic.java @@ -124,6 +124,7 @@ public abstract class BinaryArithmetic extends Operation { public static final class Add extends BinaryArithmetic { + @CompilationFinal private boolean introducesOverflow = false; @CompilationFinal private boolean introducesNA = false; public Add() { @@ -137,17 +138,22 @@ public abstract class BinaryArithmetic extends Operation { @Override public boolean introducesNA() { - return introducesNA; + return introducesNA || introducesOverflow; } @Override public int op(int left, int right) { - if (!introducesNA) { + if (!introducesOverflow) { try { - return Math.addExact(left, right); + int result = Math.addExact(left, right); + // NAs can also be introduced without a 32-bit overflow + if (!introducesNA && result == RRuntime.INT_NA) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + introducesNA = true; + } } catch (ArithmeticException e) { CompilerDirectives.transferToInterpreterAndInvalidate(); - introducesNA = true; + introducesOverflow = true; } } // Borrowed from ExactMath @@ -176,6 +182,7 @@ public abstract class BinaryArithmetic extends Operation { public static final class Subtract extends BinaryArithmetic { + @CompilationFinal private boolean introducesOverflow = false; @CompilationFinal private boolean introducesNA = false; public Subtract() { @@ -189,17 +196,23 @@ public abstract class BinaryArithmetic extends Operation { @Override public boolean introducesNA() { - return introducesNA; + return introducesNA || introducesOverflow; } @Override public int op(int left, int right) { - if (!introducesNA) { + if (!introducesOverflow) { try { - return Math.subtractExact(left, right); + int result = Math.subtractExact(left, right); + // NAs can also be introduced without a 32-bit overflow + if (!introducesNA && result == RRuntime.INT_NA) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + introducesNA = true; + } + return result; } catch (ArithmeticException e) { CompilerDirectives.transferToInterpreterAndInvalidate(); - introducesNA = true; + introducesOverflow = true; } } // Borrowed from ExactMath