From 9dd84cc19989c6d9a1b51197135fad21dd1033aa Mon Sep 17 00:00:00 2001 From: Lukas Stadler <lukas.stadler@oracle.com> Date: Mon, 26 Sep 2016 19:09:25 +0200 Subject: [PATCH] proper type profiling in IntersectFastPath --- .../base/fastpaths/IntersectFastPath.java | 46 +++++++++++-------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/IntersectFastPath.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/IntersectFastPath.java index e0f1161c1d..611cf68d89 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/IntersectFastPath.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/IntersectFastPath.java @@ -36,35 +36,43 @@ import com.oracle.truffle.r.runtime.nodes.RNode; public abstract class IntersectFastPath extends RFastPathNode { + protected static final int TYPE_LIMIT = 2; + private static final int[] EMPTY_INT_ARRAY = new int[0]; - @Specialization(guards = {"x.getLength() > 0", "y.getLength() > 0"}) - protected RAbstractIntVector intersect(RAbstractIntVector x, RAbstractIntVector y, // - @Cached("createBinaryProfile()") ConditionProfile isXSortedProfile, // - @Cached("createBinaryProfile()") ConditionProfile isYSortedProfile, // + @Specialization(limit = "TYPE_LIMIT", guards = {"x.getLength() > 0", "y.getLength() > 0", "x.getClass() == xClass", "y.getClass() == yClass"}) + protected RAbstractIntVector intersect(RAbstractIntVector x, RAbstractIntVector y, + @Cached("x.getClass()") Class<? extends RAbstractIntVector> xClass, + @Cached("y.getClass()") Class<? extends RAbstractIntVector> yClass, + @Cached("createBinaryProfile()") ConditionProfile isXSortedProfile, + @Cached("createBinaryProfile()") ConditionProfile isYSortedProfile, @Cached("createBinaryProfile()") ConditionProfile resultLengthMatchProfile) { - int xLength = x.getLength(); - int yLength = y.getLength(); + // apply the type profiles: + RAbstractIntVector profiledX = xClass.cast(x); + RAbstractIntVector profiledY = yClass.cast(y); + + int xLength = profiledX.getLength(); + int yLength = profiledY.getLength(); RNode.reportWork(this, xLength + yLength); int count = 0; int[] result = EMPTY_INT_ARRAY; int maxResultLength = Math.min(xLength, yLength); - if (isXSortedProfile.profile(isSorted(x))) { + if (isXSortedProfile.profile(isSorted(profiledX))) { RAbstractIntVector tempY; - if (!isYSortedProfile.profile(isSorted(y))) { + if (isYSortedProfile.profile(isSorted(profiledY))) { + tempY = profiledY; + } else { int[] temp = new int[yLength]; for (int i = 0; i < yLength; i++) { - temp[i] = y.getDataAt(i); + temp[i] = profiledY.getDataAt(i); } sort(temp); - tempY = RDataFactory.createIntVector(temp, y.isComplete()); - } else { - tempY = y; + tempY = RDataFactory.createIntVector(temp, profiledY.isComplete()); } int xPos = 0; int yPos = 0; - int xValue = x.getDataAt(xPos); + int xValue = profiledX.getDataAt(xPos); int yValue = tempY.getDataAt(yPos); while (true) { if (xValue == yValue) { @@ -77,7 +85,7 @@ public abstract class IntersectFastPath extends RFastPathNode { if (xPos >= xLength - 1) { break; } - int nextValue = x.getDataAt(xPos + 1); + int nextValue = profiledX.getDataAt(xPos + 1); if (xValue != nextValue) { break; } @@ -87,13 +95,13 @@ public abstract class IntersectFastPath extends RFastPathNode { if (++xPos >= xLength || ++yPos >= yLength) { break; } - xValue = x.getDataAt(xPos); + xValue = profiledX.getDataAt(xPos); yValue = tempY.getDataAt(yPos); } else if (xValue < yValue) { if (++xPos >= xLength) { break; } - xValue = x.getDataAt(xPos); + xValue = profiledX.getDataAt(xPos); } else { if (++yPos >= yLength) { break; @@ -105,12 +113,12 @@ public abstract class IntersectFastPath extends RFastPathNode { int[] temp = new int[yLength]; boolean[] used = new boolean[yLength]; for (int i = 0; i < yLength; i++) { - temp[i] = y.getDataAt(i); + temp[i] = profiledY.getDataAt(i); } sort(temp); for (int i = 0; i < xLength; i++) { - int value = x.getDataAt(i); + int value = profiledX.getDataAt(i); int pos = Arrays.binarySearch(temp, value); if (pos >= 0 && !used[pos]) { used[pos] = true; @@ -121,7 +129,7 @@ public abstract class IntersectFastPath extends RFastPathNode { } } } - return RDataFactory.createIntVector(resultLengthMatchProfile.profile(count == result.length) ? result : Arrays.copyOf(result, count), x.isComplete() | y.isComplete()); + return RDataFactory.createIntVector(resultLengthMatchProfile.profile(count == result.length) ? result : Arrays.copyOf(result, count), profiledX.isComplete() | profiledY.isComplete()); } private static boolean isSorted(RAbstractIntVector vector) { -- GitLab