From a4fe0cd1a9d6ed42f6fc3792ecee18bbe65b54f1 Mon Sep 17 00:00:00 2001 From: Florian Angerer <florian.angerer@oracle.com> Date: Wed, 23 Aug 2017 13:31:32 +0200 Subject: [PATCH] Made reuse of non-shared vector optional in cast pipeline. --- .../truffle/r/nodes/builtin/base/APerm.java | 2 +- .../truffle/r/nodes/builtin/CastBuilder.java | 14 ++++-- .../r/nodes/builtin/casts/PipelineStep.java | 12 +++-- .../builtin/casts/PipelineToCastNode.java | 8 +-- .../builtin/casts/fluent/PipelineBuilder.java | 4 +- .../truffle/r/nodes/unary/CastBaseNode.java | 36 ++++++++++--- .../r/nodes/unary/CastDoubleBaseNode.java | 4 +- .../truffle/r/nodes/unary/CastDoubleNode.java | 50 ++++++++++++------- .../r/nodes/unary/CastIntegerBaseNode.java | 8 +-- .../r/nodes/unary/CastIntegerNode.java | 48 +++++++++++------- 10 files changed, 122 insertions(+), 64 deletions(-) diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/APerm.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/APerm.java index 0938a7bde6..a4f4376e81 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/APerm.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/APerm.java @@ -55,7 +55,7 @@ public abstract class APerm extends RBuiltinNode.Arg3 { static { Casts casts = new Casts(APerm.class); casts.arg("a").mustNotBeNull(RError.Message.FIRST_ARG_MUST_BE_ARRAY); - casts.arg("perm").allowNull().mustBe(numericValue().or(stringValue()).or(complexValue())).mapIf(numericValue().or(complexValue()), asIntegerVector()); + casts.arg("perm").allowNull().mustBe(numericValue().or(stringValue()).or(complexValue())).mapIf(numericValue().or(complexValue()), asIntegerVector(true)); casts.arg("resize").mustBe(numericValue().or(logicalValue()), Message.INVALID_LOGICAL, "resize").asLogicalVector().findFirst(); } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/CastBuilder.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/CastBuilder.java index 70f34a3d31..85c0769241 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/CastBuilder.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/CastBuilder.java @@ -332,8 +332,16 @@ public final class CastBuilder { return new CoercionStep<>(RType.Integer, true); } + public static <T> PipelineStep<T, RAbstractIntVector> asIntegerVector(boolean reuseNonShared) { + return new CoercionStep<>(RType.Integer, true, false, false, false, true, reuseNonShared); + } + public static <T> PipelineStep<T, RAbstractIntVector> asIntegerVector(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes) { - return new CoercionStep<>(RType.Integer, true, preserveNames, preserveDimensions, preserveAttributes, true); + return new CoercionStep<>(RType.Integer, true, preserveNames, preserveDimensions, preserveAttributes, true, false); + } + + public static <T> PipelineStep<T, RAbstractIntVector> asIntegerVector(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes, boolean reuseNonShared) { + return new CoercionStep<>(RType.Integer, true, preserveNames, preserveDimensions, preserveAttributes, true, reuseNonShared); } public static <T> PipelineStep<T, Double> asDouble() { @@ -377,7 +385,7 @@ public final class CastBuilder { } public static <T> PipelineStep<T, RAbstractLogicalVector> asLogicalVector(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes) { - return new CoercionStep<>(RType.Logical, true, preserveNames, preserveDimensions, preserveAttributes, false); + return new CoercionStep<>(RType.Logical, true, preserveNames, preserveDimensions, preserveAttributes, false, false); } public static PipelineStep<Byte, Boolean> asBoolean() { @@ -389,7 +397,7 @@ public final class CastBuilder { } public static <T> PipelineStep<T, RAbstractVector> asVector(boolean preserveNonVector) { - return new CoercionStep<>(RType.Any, true, false, false, false, preserveNonVector); + return new CoercionStep<>(RType.Any, true, false, false, false, preserveNonVector, false); } /** diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineStep.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineStep.java index 7f4776692e..c172a54400 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineStep.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineStep.java @@ -176,21 +176,27 @@ public abstract class PipelineStep<T, R> { */ public final boolean vectorCoercion; + /** + * Whether the cast should reuse a non-shared vector. + */ + public final boolean reuseNonShared; + public CoercionStep(RType type, boolean vectorCoercion) { - this(type, vectorCoercion, false, false, false, true); + this(type, vectorCoercion, false, false, false, true, false); } public CoercionStep(RType type, boolean vectorCoercion, boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes) { - this(type, vectorCoercion, preserveNames, preserveDimensions, preserveAttributes, true); + this(type, vectorCoercion, preserveNames, preserveDimensions, preserveAttributes, true, false); } - public CoercionStep(RType type, boolean vectorCoercion, boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes, boolean preserveNonVector) { + public CoercionStep(RType type, boolean vectorCoercion, boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes, boolean preserveNonVector, boolean reuseNonShared) { this.type = type; this.vectorCoercion = vectorCoercion; this.preserveNames = preserveNames; this.preserveAttributes = preserveAttributes; this.preserveDimensions = preserveDimensions; this.preserveNonVector = preserveNonVector; + this.reuseNonShared = reuseNonShared; } public RType getType() { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineToCastNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineToCastNode.java index 12aaa0ec1f..f9f7dde80c 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineToCastNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineToCastNode.java @@ -226,11 +226,11 @@ public final class PipelineToCastNode { RType type = step.getType(); switch (type) { case Integer: - return step.vectorCoercion ? CastIntegerNodeGen.create(step.preserveNames, step.preserveDimensions, step.preserveAttributes) - : CastIntegerBaseNodeGen.create(step.preserveNames, step.preserveDimensions, step.preserveAttributes); + return step.vectorCoercion ? CastIntegerNodeGen.create(step.preserveNames, step.preserveDimensions, step.preserveAttributes, false, step.reuseNonShared) + : CastIntegerBaseNodeGen.create(step.preserveNames, step.preserveDimensions, step.preserveAttributes, false, step.reuseNonShared); case Double: - return step.vectorCoercion ? CastDoubleNodeGen.create(step.preserveNames, step.preserveDimensions, step.preserveAttributes) - : CastDoubleBaseNodeGen.create(step.preserveNames, step.preserveDimensions, step.preserveAttributes); + return step.vectorCoercion ? CastDoubleNodeGen.create(step.preserveNames, step.preserveDimensions, step.preserveAttributes, false, step.reuseNonShared) + : CastDoubleBaseNodeGen.create(step.preserveNames, step.preserveDimensions, step.preserveAttributes, false, step.reuseNonShared); case Character: return step.vectorCoercion ? CastStringNodeGen.create(step.preserveNames, step.preserveDimensions, step.preserveAttributes) : CastStringBaseNodeGen.create(step.preserveNames, step.preserveDimensions, step.preserveAttributes); diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/PipelineBuilder.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/PipelineBuilder.java index 5637486524..b0be4bf824 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/PipelineBuilder.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/PipelineBuilder.java @@ -76,12 +76,12 @@ public final class PipelineBuilder { } public void appendAsVector(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes, boolean preserveNonVector) { - append(new CoercionStep<>(RType.Any, true, preserveNames, preserveDimensions, preserveAttributes, preserveNonVector)); + append(new CoercionStep<>(RType.Any, true, preserveNames, preserveDimensions, preserveAttributes, preserveNonVector, false)); } public void appendAsVector(RType type, boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes) { assert type == RType.Integer || type == RType.Double || type == RType.Complex || type == RType.Character || type == RType.Logical || type == RType.Raw; - append(new CoercionStep<>(type, true, preserveNames, preserveDimensions, preserveAttributes, true)); + append(new CoercionStep<>(type, true, preserveNames, preserveDimensions, preserveAttributes, true, false)); } public void appendNotNA(Object naReplacement, Message message, Object[] messageArgs) { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastBaseNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastBaseNode.java index 2ff5495c46..cd0d124a10 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastBaseNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastBaseNode.java @@ -28,6 +28,7 @@ import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.interop.TruffleObject; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimNamesAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; @@ -43,24 +44,18 @@ import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RSharingAttributeStorage; import com.oracle.truffle.r.runtime.data.RStringVector; import com.oracle.truffle.r.runtime.data.RTypedValue; -import com.oracle.truffle.r.runtime.data.RVector; import com.oracle.truffle.r.runtime.data.model.RAbstractContainer; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.env.REnvironment; public abstract class CastBaseNode extends CastNode { - protected static boolean isReusable(RAbstractVector v) { - if (v instanceof RSharingAttributeStorage) { - return !((RSharingAttributeStorage) v).isShared(); - } - return false; - } - private final BranchProfile listCoercionErrorBranch = BranchProfile.create(); private final ConditionProfile hasDimNamesProfile = ConditionProfile.createBinaryProfile(); private final NullProfile hasDimensionsProfile = NullProfile.create(); private final NullProfile hasNamesProfile = NullProfile.create(); + private final ValueProfile reuseClassProfile; + @Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create(); @Child private GetDimAttributeNode getDimNode; @Child private SetDimNamesAttributeNode setDimNamesNode; @@ -70,6 +65,9 @@ public abstract class CastBaseNode extends CastNode { private final boolean preserveDimensions; private final boolean preserveAttributes; + /** {@code true} if a cast should try to reuse a non-shared vector. */ + private final boolean reuseNonShared; + /** * GnuR provides several, sometimes incompatible, ways to coerce given value to given type. This * flag tells the cast node that it should behave in a way compatible with functions exposed by @@ -82,6 +80,10 @@ public abstract class CastBaseNode extends CastNode { } protected CastBaseNode(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes, boolean forRFFI) { + this(preserveNames, preserveDimensions, preserveAttributes, forRFFI, false); + } + + protected CastBaseNode(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes, boolean forRFFI, boolean reuseNonShared) { this.preserveNames = preserveNames; this.preserveDimensions = preserveDimensions; this.preserveAttributes = preserveAttributes; @@ -89,6 +91,8 @@ public abstract class CastBaseNode extends CastNode { if (preserveDimensions) { getDimNamesNode = GetDimNamesAttributeNode.create(); } + this.reuseNonShared = reuseNonShared; + reuseClassProfile = reuseNonShared ? ValueProfile.createClassProfile() : null; } public final boolean preserveNames() { @@ -107,6 +111,10 @@ public abstract class CastBaseNode extends CastNode { return preserveAttributes || preserveNames || preserveDimensions; } + public final boolean reuseNonShared() { + return reuseNonShared; + } + protected abstract RType getTargetType(); protected RError throwCannotCoerceListError(String type) { @@ -174,4 +182,16 @@ public abstract class CastBaseNode extends CastNode { } return RNull.instance; } + + protected boolean isReusable(RAbstractVector v) { + if (reuseNonShared && v instanceof RSharingAttributeStorage) { + return !((RSharingAttributeStorage) v).isShared(); + } + return false; + } + + protected RAbstractVector castWithReuse(RType targetType, RAbstractVector v, ConditionProfile naProfile) { + assert isReusable(v); + return reuseClassProfile.profile(v.castSafe(targetType, naProfile, preserveAttributes())); + } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleBaseNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleBaseNode.java index 9abeb67a8d..6fee221df0 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleBaseNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleBaseNode.java @@ -44,8 +44,8 @@ public abstract class CastDoubleBaseNode extends CastBaseNode { super(preserveNames, preserveDimensions, preserveAttributes); } - protected CastDoubleBaseNode(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes, boolean forRFFI) { - super(preserveNames, preserveDimensions, preserveAttributes, forRFFI); + protected CastDoubleBaseNode(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes, boolean forRFFI, boolean withReuse) { + super(preserveNames, preserveDimensions, preserveAttributes, forRFFI, withReuse); } @Override diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleNode.java index 40e9f22fcd..5f8e94ac36 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleNode.java @@ -53,12 +53,12 @@ import com.oracle.truffle.r.runtime.interop.ForeignArray2RNodeGen; @ImportStatic(RRuntime.class) public abstract class CastDoubleNode extends CastDoubleBaseNode { - protected CastDoubleNode(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes, boolean forRFFI) { - super(preserveNames, preserveDimensions, preserveAttributes, forRFFI); + protected CastDoubleNode(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes, boolean forRFFI, boolean withReuse) { + super(preserveNames, preserveDimensions, preserveAttributes, forRFFI, withReuse); } protected CastDoubleNode(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes) { - super(preserveNames, preserveDimensions, preserveAttributes); + super(preserveNames, preserveDimensions, preserveAttributes, false, false); } @Child private CastDoubleNode recursiveCastDouble; @@ -66,7 +66,7 @@ public abstract class CastDoubleNode extends CastDoubleBaseNode { private Object castDoubleRecursive(Object o) { if (recursiveCastDouble == null) { CompilerDirectives.transferToInterpreterAndInvalidate(); - recursiveCastDouble = insert(CastDoubleNodeGen.create(preserveNames(), preserveDimensions(), preserveRegAttributes())); + recursiveCastDouble = insert(CastDoubleNodeGen.create(preserveNames(), preserveDimensions(), preserveRegAttributes(), false, reuseNonShared())); } return recursiveCastDouble.executeDouble(o); } @@ -92,14 +92,29 @@ public abstract class CastDoubleNode extends CastDoubleBaseNode { return vectorCopy(operand, ddata, !seenNA); } + @Specialization(guards = "isReusable(operand)") + protected RAbstractDoubleVector doIntVectorReuse(RAbstractIntVector operand) { + return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile()); + } + + @Specialization(guards = "isReusable(operand)") + protected RAbstractDoubleVector doLogicalVectorDimsReuse(RAbstractLogicalVector operand) { + return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile()); + } + + @Specialization(guards = "isReusable(operand)") + protected RAbstractDoubleVector doRawVectorReuse(RRawVector operand) { + return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile()); + } + @Specialization - protected RAbstractDoubleVector doIntVector(RAbstractIntVector operand) { - return castWithReuse(operand, index -> naCheck.convertIntToDouble(operand.getDataAt(index))); + protected RDoubleVector doIntVector(RAbstractIntVector operand) { + return createResultVector(operand, index -> naCheck.convertIntToDouble(operand.getDataAt(index))); } @Specialization - protected RAbstractDoubleVector doLogicalVectorDims(RAbstractLogicalVector operand) { - return castWithReuse(operand, index -> naCheck.convertLogicalToDouble(operand.getDataAt(index))); + protected RDoubleVector doLogicalVectorDims(RAbstractLogicalVector operand) { + return createResultVector(operand, index -> naCheck.convertLogicalToDouble(operand.getDataAt(index))); } @Specialization @@ -158,8 +173,8 @@ public abstract class CastDoubleNode extends CastDoubleBaseNode { } @Specialization - protected RAbstractDoubleVector doRawVector(RRawVector operand) { - return castWithReuse(operand, index -> RRuntime.raw2double(operand.getDataAt(index))); + protected RDoubleVector doRawVector(RRawVector operand) { + return createResultVector(operand, index -> RRuntime.raw2double(operand.getDataAt(index))); } @Specialization @@ -223,23 +238,20 @@ public abstract class CastDoubleNode extends CastDoubleBaseNode { throw error(RError.Message.CANNOT_COERCE_EXTERNAL_OBJECT_TO_VECTOR, "vector"); } - private RAbstractDoubleVector castWithReuse(RAbstractVector v, IntToDoubleFunction elementFunction) { - if (isReusable(v)) { - return (RAbstractDoubleVector) v.castSafe(RType.Double, naProfile.getConditionProfile(), preserveAttributes()); - } - return createResultVector(v, elementFunction); + public static CastDoubleNode create() { + return CastDoubleNodeGen.create(true, true, true, false, false); } - public static CastDoubleNode create() { - return CastDoubleNodeGen.create(true, true, true); + public static CastDoubleNode createWithReuse() { + return CastDoubleNodeGen.create(true, true, true, false, true); } public static CastDoubleNode createForRFFI(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes) { - return CastDoubleNodeGen.create(preserveNames, preserveDimensions, preserveAttributes, true); + return CastDoubleNodeGen.create(preserveNames, preserveDimensions, preserveAttributes, true, false); } public static CastDoubleNode createNonPreserving() { - return CastDoubleNodeGen.create(false, false, false); + return CastDoubleNodeGen.create(false, false, false, false, false); } protected ForeignArray2R createForeignArray2RNode() { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerBaseNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerBaseNode.java index 18c5757508..10e420d5b1 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerBaseNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerBaseNode.java @@ -41,12 +41,12 @@ public abstract class CastIntegerBaseNode extends CastBaseNode { @Child private CastIntegerNode recursiveCastInteger; - protected CastIntegerBaseNode(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes, boolean forRFFI) { - super(preserveNames, preserveDimensions, preserveAttributes, forRFFI); + protected CastIntegerBaseNode(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes, boolean forRFFI, boolean withReuse) { + super(preserveNames, preserveDimensions, preserveAttributes, forRFFI, withReuse); } protected CastIntegerBaseNode(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes) { - super(preserveNames, preserveDimensions, preserveAttributes); + super(preserveNames, preserveDimensions, preserveAttributes, false, false); } @Override @@ -57,7 +57,7 @@ public abstract class CastIntegerBaseNode extends CastBaseNode { protected Object castIntegerRecursive(Object o) { if (recursiveCastInteger == null) { CompilerDirectives.transferToInterpreterAndInvalidate(); - recursiveCastInteger = insert(CastIntegerNodeGen.create(preserveNames(), preserveDimensions(), preserveRegAttributes())); + recursiveCastInteger = insert(CastIntegerNodeGen.create(preserveNames(), preserveDimensions(), preserveRegAttributes(), false, reuseNonShared())); } return recursiveCastInteger.executeInt(o); } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerNode.java index 0d7552c06c..997193d66f 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerNode.java @@ -57,12 +57,12 @@ public abstract class CastIntegerNode extends CastIntegerBaseNode { private final BranchProfile warningBranch = BranchProfile.create(); - protected CastIntegerNode(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes, boolean forRFFI) { - super(preserveNames, preserveDimensions, preserveAttributes, forRFFI); + protected CastIntegerNode(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes, boolean forRFFI, boolean withReuse) { + super(preserveNames, preserveDimensions, preserveAttributes, forRFFI, withReuse); } protected CastIntegerNode(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes) { - super(preserveNames, preserveDimensions, preserveAttributes); + super(preserveNames, preserveDimensions, preserveAttributes, false, false); } public abstract Object executeInt(int o); @@ -160,19 +160,34 @@ public abstract class CastIntegerNode extends CastIntegerBaseNode { return vectorCopy(operand, idata, !seenNA); } + @Specialization(guards = "isReusable(operand)") + public RAbstractIntVector doLogicalVectorReuse(RAbstractLogicalVector operand) { + return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile()); + } + + @Specialization(guards = "isReusable(operand)") + protected RAbstractIntVector doDoubleVectorReuse(RAbstractDoubleVector operand) { + return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile()); + } + + @Specialization(guards = "isReusable(operand)") + protected RAbstractIntVector doRawVectorReuse(RAbstractRawVector operand) { + return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile()); + } + @Specialization - public RAbstractIntVector doLogicalVector(RAbstractLogicalVector operand) { - return castWithReuse(operand, index -> naCheck.convertLogicalToInt(operand.getDataAt(index))); + public RIntVector doLogicalVector(RAbstractLogicalVector operand) { + return createResultVector(operand, index -> naCheck.convertLogicalToInt(operand.getDataAt(index))); } @Specialization - protected RAbstractIntVector doDoubleVector(RAbstractDoubleVector operand) { - return castWithReuse(operand, index -> naCheck.convertDoubleToInt(operand.getDataAt(index))); + protected RIntVector doDoubleVector(RAbstractDoubleVector operand) { + return createResultVector(operand, index -> naCheck.convertDoubleToInt(operand.getDataAt(index))); } @Specialization - protected RAbstractIntVector doRawVector(RAbstractRawVector operand) { - return castWithReuse(operand, index -> RRuntime.raw2int(operand.getDataAt(index))); + protected RIntVector doRawVector(RAbstractRawVector operand) { + return createResultVector(operand, index -> RRuntime.raw2int(operand.getDataAt(index))); } @Specialization @@ -241,23 +256,20 @@ public abstract class CastIntegerNode extends CastIntegerBaseNode { return arg instanceof RIntVector; } - private RAbstractIntVector castWithReuse(RAbstractVector v, IntToIntFunction elementFunction) { - if (isReusable(v)) { - return (RAbstractIntVector) v.castSafe(RType.Integer, naProfile.getConditionProfile(), preserveAttributes()); - } - return createResultVector(v, elementFunction); + public static CastIntegerNode create() { + return CastIntegerNodeGen.create(true, true, true, false, false); } - public static CastIntegerNode create() { - return CastIntegerNodeGen.create(true, true, true); + public static CastIntegerNode createWithReuse() { + return CastIntegerNodeGen.create(true, true, true, false, true); } public static CastIntegerNode createForRFFI(boolean preserveNames, boolean preserveDimensions, boolean preserveAttributes) { - return CastIntegerNodeGen.create(preserveNames, preserveDimensions, preserveAttributes, true); + return CastIntegerNodeGen.create(preserveNames, preserveDimensions, preserveAttributes, true, false); } public static CastIntegerNode createNonPreserving() { - return CastIntegerNodeGen.create(false, false, false); + return CastIntegerNodeGen.create(false, false, false, false, false); } protected ForeignArray2R createForeignArray2RNode() { -- GitLab