From 9dd84cc19989c6d9a1b51197135fad21dd1033aa Mon Sep 17 00:00:00 2001
From: Lukas Stadler <lukas.stadler@oracle.com>
Date: Mon, 26 Sep 2016 19:09:25 +0200
Subject: [PATCH] proper type profiling in IntersectFastPath

---
 .../base/fastpaths/IntersectFastPath.java     | 46 +++++++++++--------
 1 file changed, 27 insertions(+), 19 deletions(-)

diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/IntersectFastPath.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/IntersectFastPath.java
index e0f1161c1d..611cf68d89 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/IntersectFastPath.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/IntersectFastPath.java
@@ -36,35 +36,43 @@ import com.oracle.truffle.r.runtime.nodes.RNode;
 
 public abstract class IntersectFastPath extends RFastPathNode {
 
+    protected static final int TYPE_LIMIT = 2;
+
     private static final int[] EMPTY_INT_ARRAY = new int[0];
 
-    @Specialization(guards = {"x.getLength() > 0", "y.getLength() > 0"})
-    protected RAbstractIntVector intersect(RAbstractIntVector x, RAbstractIntVector y, //
-                    @Cached("createBinaryProfile()") ConditionProfile isXSortedProfile, //
-                    @Cached("createBinaryProfile()") ConditionProfile isYSortedProfile, //
+    @Specialization(limit = "TYPE_LIMIT", guards = {"x.getLength() > 0", "y.getLength() > 0", "x.getClass() == xClass", "y.getClass() == yClass"})
+    protected RAbstractIntVector intersect(RAbstractIntVector x, RAbstractIntVector y,
+                    @Cached("x.getClass()") Class<? extends RAbstractIntVector> xClass,
+                    @Cached("y.getClass()") Class<? extends RAbstractIntVector> yClass,
+                    @Cached("createBinaryProfile()") ConditionProfile isXSortedProfile,
+                    @Cached("createBinaryProfile()") ConditionProfile isYSortedProfile,
                     @Cached("createBinaryProfile()") ConditionProfile resultLengthMatchProfile) {
-        int xLength = x.getLength();
-        int yLength = y.getLength();
+        // apply the type profiles:
+        RAbstractIntVector profiledX = xClass.cast(x);
+        RAbstractIntVector profiledY = yClass.cast(y);
+
+        int xLength = profiledX.getLength();
+        int yLength = profiledY.getLength();
         RNode.reportWork(this, xLength + yLength);
 
         int count = 0;
         int[] result = EMPTY_INT_ARRAY;
         int maxResultLength = Math.min(xLength, yLength);
-        if (isXSortedProfile.profile(isSorted(x))) {
+        if (isXSortedProfile.profile(isSorted(profiledX))) {
             RAbstractIntVector tempY;
-            if (!isYSortedProfile.profile(isSorted(y))) {
+            if (isYSortedProfile.profile(isSorted(profiledY))) {
+                tempY = profiledY;
+            } else {
                 int[] temp = new int[yLength];
                 for (int i = 0; i < yLength; i++) {
-                    temp[i] = y.getDataAt(i);
+                    temp[i] = profiledY.getDataAt(i);
                 }
                 sort(temp);
-                tempY = RDataFactory.createIntVector(temp, y.isComplete());
-            } else {
-                tempY = y;
+                tempY = RDataFactory.createIntVector(temp, profiledY.isComplete());
             }
             int xPos = 0;
             int yPos = 0;
-            int xValue = x.getDataAt(xPos);
+            int xValue = profiledX.getDataAt(xPos);
             int yValue = tempY.getDataAt(yPos);
             while (true) {
                 if (xValue == yValue) {
@@ -77,7 +85,7 @@ public abstract class IntersectFastPath extends RFastPathNode {
                         if (xPos >= xLength - 1) {
                             break;
                         }
-                        int nextValue = x.getDataAt(xPos + 1);
+                        int nextValue = profiledX.getDataAt(xPos + 1);
                         if (xValue != nextValue) {
                             break;
                         }
@@ -87,13 +95,13 @@ public abstract class IntersectFastPath extends RFastPathNode {
                     if (++xPos >= xLength || ++yPos >= yLength) {
                         break;
                     }
-                    xValue = x.getDataAt(xPos);
+                    xValue = profiledX.getDataAt(xPos);
                     yValue = tempY.getDataAt(yPos);
                 } else if (xValue < yValue) {
                     if (++xPos >= xLength) {
                         break;
                     }
-                    xValue = x.getDataAt(xPos);
+                    xValue = profiledX.getDataAt(xPos);
                 } else {
                     if (++yPos >= yLength) {
                         break;
@@ -105,12 +113,12 @@ public abstract class IntersectFastPath extends RFastPathNode {
             int[] temp = new int[yLength];
             boolean[] used = new boolean[yLength];
             for (int i = 0; i < yLength; i++) {
-                temp[i] = y.getDataAt(i);
+                temp[i] = profiledY.getDataAt(i);
             }
             sort(temp);
 
             for (int i = 0; i < xLength; i++) {
-                int value = x.getDataAt(i);
+                int value = profiledX.getDataAt(i);
                 int pos = Arrays.binarySearch(temp, value);
                 if (pos >= 0 && !used[pos]) {
                     used[pos] = true;
@@ -121,7 +129,7 @@ public abstract class IntersectFastPath extends RFastPathNode {
                 }
             }
         }
-        return RDataFactory.createIntVector(resultLengthMatchProfile.profile(count == result.length) ? result : Arrays.copyOf(result, count), x.isComplete() | y.isComplete());
+        return RDataFactory.createIntVector(resultLengthMatchProfile.profile(count == result.length) ? result : Arrays.copyOf(result, count), profiledX.isComplete() | profiledY.isComplete());
     }
 
     private static boolean isSorted(RAbstractIntVector vector) {
-- 
GitLab