From 694eb3daf430188f40187d15a4f32e789de574f3 Mon Sep 17 00:00:00 2001
From: Florian Angerer <florian.angerer@oracle.com>
Date: Fri, 25 Aug 2017 08:18:47 +0200
Subject: [PATCH] Fixed specializations with respect to fallback.

---
 .../truffle/r/nodes/unary/CastDoubleNode.java | 51 ++++++++++---------
 .../r/nodes/unary/CastIntegerNode.java        | 43 ++++++++--------
 2 files changed, 48 insertions(+), 46 deletions(-)

diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleNode.java
index 67ee125a0d..1429749749 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleNode.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastDoubleNode.java
@@ -31,6 +31,7 @@ import com.oracle.truffle.api.dsl.Specialization;
 import com.oracle.truffle.api.interop.TruffleObject;
 import com.oracle.truffle.api.profiles.BranchProfile;
 import com.oracle.truffle.api.profiles.ConditionProfile;
+import com.oracle.truffle.api.profiles.ValueProfile;
 import com.oracle.truffle.r.runtime.RError;
 import com.oracle.truffle.r.runtime.RRuntime;
 import com.oracle.truffle.r.runtime.RType;
@@ -39,13 +40,13 @@ import com.oracle.truffle.r.runtime.data.RComplexVector;
 import com.oracle.truffle.r.runtime.data.RDataFactory;
 import com.oracle.truffle.r.runtime.data.RDoubleVector;
 import com.oracle.truffle.r.runtime.data.RList;
-import com.oracle.truffle.r.runtime.data.RRawVector;
 import com.oracle.truffle.r.runtime.data.RStringVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractContainer;
 import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractListVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
+import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
 import com.oracle.truffle.r.runtime.interop.ForeignArray2R;
 import com.oracle.truffle.r.runtime.interop.ForeignArray2RNodeGen;
@@ -92,31 +93,36 @@ public abstract class CastDoubleNode extends CastDoubleBaseNode {
         return vectorCopy(operand, ddata, !seenNA);
     }
 
-    @Specialization(guards = "useClosure()")
-    protected RAbstractDoubleVector doIntVectorReuse(RAbstractIntVector operand) {
-        return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile());
-    }
-
-    @Specialization(guards = "useClosure()")
-    protected RAbstractDoubleVector doLogicalVectorDimsReuse(RAbstractLogicalVector operand) {
-        return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile());
-    }
-
-    @Specialization(guards = "useClosure()")
-    protected RAbstractDoubleVector doRawVectorReuse(RRawVector operand) {
-        return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile());
-    }
-
-    @Specialization(guards = "!useClosure()")
-    protected RDoubleVector doIntVector(RAbstractIntVector operand) {
+    @Specialization
+    protected RAbstractDoubleVector doIntVector(RAbstractIntVector x,
+                    @Cached("createClassProfile()") ValueProfile operandTypeProfile) {
+        RAbstractIntVector operand = operandTypeProfile.profile(x);
+        if (useClosure()) {
+            return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile());
+        }
         return createResultVector(operand, index -> naCheck.convertIntToDouble(operand.getDataAt(index)));
     }
 
-    @Specialization(guards = "!useClosure()")
-    protected RDoubleVector doLogicalVectorDims(RAbstractLogicalVector operand) {
+    @Specialization
+    protected RAbstractDoubleVector doLogicalVector(RAbstractLogicalVector x,
+                    @Cached("createClassProfile()") ValueProfile operandTypeProfile) {
+        RAbstractLogicalVector operand = operandTypeProfile.profile(x);
+        if (useClosure()) {
+            return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile());
+        }
         return createResultVector(operand, index -> naCheck.convertLogicalToDouble(operand.getDataAt(index)));
     }
 
+    @Specialization
+    protected RAbstractDoubleVector doRawVector(RAbstractRawVector x,
+                    @Cached("createClassProfile()") ValueProfile operandTypeProfile) {
+        RAbstractRawVector operand = operandTypeProfile.profile(x);
+        if (useClosure()) {
+            return (RAbstractDoubleVector) castWithReuse(RType.Double, operand, naProfile.getConditionProfile());
+        }
+        return createResultVector(operand, index -> RRuntime.raw2double(operand.getDataAt(index)));
+    }
+
     @Specialization
     protected RDoubleVector doStringVector(RStringVector operand,
                     @Cached("createBinaryProfile()") ConditionProfile emptyStringProfile,
@@ -172,11 +178,6 @@ public abstract class CastDoubleNode extends CastDoubleBaseNode {
         return vectorCopy(operand, ddata, naCheck.neverSeenNA());
     }
 
-    @Specialization(guards = "!useClosure()")
-    protected RDoubleVector doRawVector(RRawVector operand) {
-        return createResultVector(operand, index -> RRuntime.raw2double(operand.getDataAt(index)));
-    }
-
     @Specialization
     protected RAbstractDoubleVector doDoubleVector(RAbstractDoubleVector operand) {
         return operand;
diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerNode.java
index 636037c3cf..b937b63904 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerNode.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/CastIntegerNode.java
@@ -28,6 +28,7 @@ import com.oracle.truffle.api.dsl.Specialization;
 import com.oracle.truffle.api.interop.TruffleObject;
 import com.oracle.truffle.api.profiles.BranchProfile;
 import com.oracle.truffle.api.profiles.ConditionProfile;
+import com.oracle.truffle.api.profiles.ValueProfile;
 import com.oracle.truffle.r.runtime.RError;
 import com.oracle.truffle.r.runtime.RRuntime;
 import com.oracle.truffle.r.runtime.RType;
@@ -160,33 +161,33 @@ public abstract class CastIntegerNode extends CastIntegerBaseNode {
         return vectorCopy(operand, idata, !seenNA);
     }
 
-    @Specialization(guards = "useClosure()")
-    public RAbstractIntVector doLogicalVectorReuse(RAbstractLogicalVector operand) {
-        return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile());
-    }
-
-    @Specialization(guards = "useClosure()")
-    protected RAbstractIntVector doDoubleVectorReuse(RAbstractDoubleVector operand) {
-        return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile());
-    }
-
-    @Specialization(guards = "useClosure()")
-    protected RAbstractIntVector doRawVectorReuse(RAbstractRawVector operand) {
-        return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile());
-    }
-
-    @Specialization(guards = "!useClosure()")
-    public RIntVector doLogicalVector(RAbstractLogicalVector operand) {
+    @Specialization
+    public RAbstractIntVector doLogicalVector(RAbstractLogicalVector x,
+                    @Cached("createClassProfile()") ValueProfile operandTypeProfile) {
+        RAbstractLogicalVector operand = operandTypeProfile.profile(x);
+        if (useClosure()) {
+            return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile());
+        }
         return createResultVector(operand, index -> naCheck.convertLogicalToInt(operand.getDataAt(index)));
     }
 
-    @Specialization(guards = "!useClosure()")
-    protected RIntVector doDoubleVector(RAbstractDoubleVector operand) {
+    @Specialization
+    protected RAbstractIntVector doDoubleVector(RAbstractDoubleVector x,
+                    @Cached("createClassProfile()") ValueProfile operandTypeProfile) {
+        RAbstractDoubleVector operand = operandTypeProfile.profile(x);
+        if (useClosure()) {
+            return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile());
+        }
         return vectorCopy(operand, naCheck.convertDoubleVectorToIntData(operand), naCheck.neverSeenNA());
     }
 
-    @Specialization(guards = "!useClosure()")
-    protected RIntVector doRawVector(RAbstractRawVector operand) {
+    @Specialization
+    protected RAbstractIntVector doRawVector(RAbstractRawVector x,
+                    @Cached("createClassProfile()") ValueProfile operandTypeProfile) {
+        RAbstractRawVector operand = operandTypeProfile.profile(x);
+        if (useClosure()) {
+            return (RAbstractIntVector) castWithReuse(RType.Integer, operand, naProfile.getConditionProfile());
+        }
         return createResultVector(operand, index -> RRuntime.raw2int(operand.getDataAt(index)));
     }
 
-- 
GitLab