diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleNode.java index 67ee125a0dec3d022d9f92eaae74a3f13ca67253..142974974969bfdce18caf7c210511097739755e 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleNode.java @@ -31,6 +31,7 @@ import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.interop.TruffleObject; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.RType; @@ -39,13 +40,13 @@ import com.oracle.truffle.r.runtime.data.RComplexVector; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.RList; -import com.oracle.truffle.r.runtime.data.RRawVector; import com.oracle.truffle.r.runtime.data.RStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractContainer; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractListVector; import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.interop.ForeignArray2R; import com.oracle.truffle.r.runtime.interop.ForeignArray2RNodeGen; @@ -92,31 +93,36 @@ public abstract class CastDoubleNode extends CastDoubleBaseNode { return vectorCopy(operand, ddata, !seenNA); } - @Specialization(guards = "useClosure()") - protected RAbstractDoubleVector doIntVectorReuse(RAbstractIntVector operand) { - return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile()); - } - - @Specialization(guards = "useClosure()") - protected RAbstractDoubleVector doLogicalVectorDimsReuse(RAbstractLogicalVector operand) { - return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile()); - } - - @Specialization(guards = "useClosure()") - protected RAbstractDoubleVector doRawVectorReuse(RRawVector operand) { - return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile()); - } - - @Specialization(guards = "!useClosure()") - protected RDoubleVector doIntVector(RAbstractIntVector operand) { + @Specialization + protected RAbstractDoubleVector doIntVector(RAbstractIntVector x, + @Cached("createClassProfile()") ValueProfile operandTypeProfile) { + RAbstractIntVector operand = operandTypeProfile.profile(x); + if (useClosure()) { + return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile()); + } return createResultVector(operand, index -> naCheck.convertIntToDouble(operand.getDataAt(index))); } - @Specialization(guards = "!useClosure()") - protected RDoubleVector doLogicalVectorDims(RAbstractLogicalVector operand) { + @Specialization + protected RAbstractDoubleVector doLogicalVector(RAbstractLogicalVector x, + @Cached("createClassProfile()") ValueProfile operandTypeProfile) { + RAbstractLogicalVector operand = operandTypeProfile.profile(x); + if (useClosure()) { + return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile()); + } return createResultVector(operand, index -> naCheck.convertLogicalToDouble(operand.getDataAt(index))); } + @Specialization + protected RAbstractDoubleVector doRawVector(RAbstractRawVector x, + @Cached("createClassProfile()") ValueProfile operandTypeProfile) { + RAbstractRawVector operand = operandTypeProfile.profile(x); + if (useClosure()) { + return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile()); + } + return createResultVector(operand, index -> RRuntime.raw2double(operand.getDataAt(index))); + } + @Specialization protected RDoubleVector doStringVector(RStringVector operand, @Cached("createBinaryProfile()") ConditionProfile emptyStringProfile, @@ -172,11 +178,6 @@ public abstract class CastDoubleNode extends CastDoubleBaseNode { return vectorCopy(operand, ddata, naCheck.neverSeenNA()); } - @Specialization(guards = "!useClosure()") - protected RDoubleVector doRawVector(RRawVector operand) { - return createResultVector(operand, index -> RRuntime.raw2double(operand.getDataAt(index))); - } - @Specialization protected RAbstractDoubleVector doDoubleVector(RAbstractDoubleVector operand) { return operand; diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerNode.java index 636037c3cf469d5c28c220d53131410507d39f7d..b937b63904c6c33163604e8d7008802d65a44041 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerNode.java @@ -28,6 +28,7 @@ import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.interop.TruffleObject; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.RType; @@ -160,33 +161,33 @@ public abstract class CastIntegerNode extends CastIntegerBaseNode { return vectorCopy(operand, idata, !seenNA); } - @Specialization(guards = "useClosure()") - public RAbstractIntVector doLogicalVectorReuse(RAbstractLogicalVector operand) { - return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile()); - } - - @Specialization(guards = "useClosure()") - protected RAbstractIntVector doDoubleVectorReuse(RAbstractDoubleVector operand) { - return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile()); - } - - @Specialization(guards = "useClosure()") - protected RAbstractIntVector doRawVectorReuse(RAbstractRawVector operand) { - return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile()); - } - - @Specialization(guards = "!useClosure()") - public RIntVector doLogicalVector(RAbstractLogicalVector operand) { + @Specialization + public RAbstractIntVector doLogicalVector(RAbstractLogicalVector x, + @Cached("createClassProfile()") ValueProfile operandTypeProfile) { + RAbstractLogicalVector operand = operandTypeProfile.profile(x); + if (useClosure()) { + return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile()); + } return createResultVector(operand, index -> naCheck.convertLogicalToInt(operand.getDataAt(index))); } - @Specialization(guards = "!useClosure()") - protected RIntVector doDoubleVector(RAbstractDoubleVector operand) { + @Specialization + protected RAbstractIntVector doDoubleVector(RAbstractDoubleVector x, + @Cached("createClassProfile()") ValueProfile operandTypeProfile) { + RAbstractDoubleVector operand = operandTypeProfile.profile(x); + if (useClosure()) { + return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile()); + } return vectorCopy(operand, naCheck.convertDoubleVectorToIntData(operand), naCheck.neverSeenNA()); } - @Specialization(guards = "!useClosure()") - protected RIntVector doRawVector(RAbstractRawVector operand) { + @Specialization + protected RAbstractIntVector doRawVector(RAbstractRawVector x, + @Cached("createClassProfile()") ValueProfile operandTypeProfile) { + RAbstractRawVector operand = operandTypeProfile.profile(x); + if (useClosure()) { + return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile()); + } return createResultVector(operand, index -> RRuntime.raw2int(operand.getDataAt(index))); }