From 694eb3daf430188f40187d15a4f32e789de574f3 Mon Sep 17 00:00:00 2001 From: Florian Angerer <florian.angerer@oracle.com> Date: Fri, 25 Aug 2017 08:18:47 +0200 Subject: [PATCH] Fixed specializations with respect to fallback. --- .../truffle/r/nodes/unary/CastDoubleNode.java | 51 ++++++++++--------- .../r/nodes/unary/CastIntegerNode.java | 43 ++++++++-------- 2 files changed, 48 insertions(+), 46 deletions(-) diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleNode.java index 67ee125a0d..1429749749 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleNode.java @@ -31,6 +31,7 @@ import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.interop.TruffleObject; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.RType; @@ -39,13 +40,13 @@ import com.oracle.truffle.r.runtime.data.RComplexVector; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.RList; -import com.oracle.truffle.r.runtime.data.RRawVector; import com.oracle.truffle.r.runtime.data.RStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractContainer; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractListVector; import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.interop.ForeignArray2R; import com.oracle.truffle.r.runtime.interop.ForeignArray2RNodeGen; @@ -92,31 +93,36 @@ public abstract class CastDoubleNode extends CastDoubleBaseNode { return vectorCopy(operand, ddata, !seenNA); } - @Specialization(guards = "useClosure()") - protected RAbstractDoubleVector doIntVectorReuse(RAbstractIntVector operand) { - return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile()); - } - - @Specialization(guards = "useClosure()") - protected RAbstractDoubleVector doLogicalVectorDimsReuse(RAbstractLogicalVector operand) { - return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile()); - } - - @Specialization(guards = "useClosure()") - protected RAbstractDoubleVector doRawVectorReuse(RRawVector operand) { - return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile()); - } - - @Specialization(guards = "!useClosure()") - protected RDoubleVector doIntVector(RAbstractIntVector operand) { + @Specialization + protected RAbstractDoubleVector doIntVector(RAbstractIntVector x, + @Cached("createClassProfile()") ValueProfile operandTypeProfile) { + RAbstractIntVector operand = operandTypeProfile.profile(x); + if (useClosure()) { + return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile()); + } return createResultVector(operand, index -> naCheck.convertIntToDouble(operand.getDataAt(index))); } - @Specialization(guards = "!useClosure()") - protected RDoubleVector doLogicalVectorDims(RAbstractLogicalVector operand) { + @Specialization + protected RAbstractDoubleVector doLogicalVector(RAbstractLogicalVector x, + @Cached("createClassProfile()") ValueProfile operandTypeProfile) { + RAbstractLogicalVector operand = operandTypeProfile.profile(x); + if (useClosure()) { + return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile()); + } return createResultVector(operand, index -> naCheck.convertLogicalToDouble(operand.getDataAt(index))); } + @Specialization + protected RAbstractDoubleVector doRawVector(RAbstractRawVector x, + @Cached("createClassProfile()") ValueProfile operandTypeProfile) { + RAbstractRawVector operand = operandTypeProfile.profile(x); + if (useClosure()) { + return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile()); + } + return createResultVector(operand, index -> RRuntime.raw2double(operand.getDataAt(index))); + } + @Specialization protected RDoubleVector doStringVector(RStringVector operand, @Cached("createBinaryProfile()") ConditionProfile emptyStringProfile, @@ -172,11 +178,6 @@ public abstract class CastDoubleNode extends CastDoubleBaseNode { return vectorCopy(operand, ddata, naCheck.neverSeenNA()); } - @Specialization(guards = "!useClosure()") - protected RDoubleVector doRawVector(RRawVector operand) { - return createResultVector(operand, index -> RRuntime.raw2double(operand.getDataAt(index))); - } - @Specialization protected RAbstractDoubleVector doDoubleVector(RAbstractDoubleVector operand) { return operand; diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerNode.java index 636037c3cf..b937b63904 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerNode.java @@ -28,6 +28,7 @@ import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.interop.TruffleObject; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.RType; @@ -160,33 +161,33 @@ public abstract class CastIntegerNode extends CastIntegerBaseNode { return vectorCopy(operand, idata, !seenNA); } - @Specialization(guards = "useClosure()") - public RAbstractIntVector doLogicalVectorReuse(RAbstractLogicalVector operand) { - return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile()); - } - - @Specialization(guards = "useClosure()") - protected RAbstractIntVector doDoubleVectorReuse(RAbstractDoubleVector operand) { - return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile()); - } - - @Specialization(guards = "useClosure()") - protected RAbstractIntVector doRawVectorReuse(RAbstractRawVector operand) { - return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile()); - } - - @Specialization(guards = "!useClosure()") - public RIntVector doLogicalVector(RAbstractLogicalVector operand) { + @Specialization + public RAbstractIntVector doLogicalVector(RAbstractLogicalVector x, + @Cached("createClassProfile()") ValueProfile operandTypeProfile) { + RAbstractLogicalVector operand = operandTypeProfile.profile(x); + if (useClosure()) { + return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile()); + } return createResultVector(operand, index -> naCheck.convertLogicalToInt(operand.getDataAt(index))); } - @Specialization(guards = "!useClosure()") - protected RIntVector doDoubleVector(RAbstractDoubleVector operand) { + @Specialization + protected RAbstractIntVector doDoubleVector(RAbstractDoubleVector x, + @Cached("createClassProfile()") ValueProfile operandTypeProfile) { + RAbstractDoubleVector operand = operandTypeProfile.profile(x); + if (useClosure()) { + return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile()); + } return vectorCopy(operand, naCheck.convertDoubleVectorToIntData(operand), naCheck.neverSeenNA()); } - @Specialization(guards = "!useClosure()") - protected RIntVector doRawVector(RAbstractRawVector operand) { + @Specialization + protected RAbstractIntVector doRawVector(RAbstractRawVector x, + @Cached("createClassProfile()") ValueProfile operandTypeProfile) { + RAbstractRawVector operand = operandTypeProfile.profile(x); + if (useClosure()) { + return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile()); + } return createResultVector(operand, index -> RRuntime.raw2int(operand.getDataAt(index))); } -- GitLab