From 984629550dcaa999f44d28f5d1cf38d34f460312 Mon Sep 17 00:00:00 2001 From: Christian Humer <christian.humer@oracle.com> Date: Wed, 22 Nov 2017 19:21:04 +0100 Subject: [PATCH] Extract one more cached class to make the DSL happy. --- .../base/fastpaths/IntersectFastPath.java | 61 +++++++++++++------ 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/IntersectFastPath.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/IntersectFastPath.java index f0f89021ec..268ab3930c 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/IntersectFastPath.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/IntersectFastPath.java @@ -138,47 +138,53 @@ public abstract class IntersectFastPath extends RFastPathNode { return new IntersectSortedNode(false); } - @Specialization(limit = "1", guards = {"x.getClass() == xClass", "y.getClass() == yClass", "length(x, xClass) > 0", "length(y, yClass) > 0"}, rewriteOn = IllegalArgumentException.class) + @Specialization(limit = "1", guards = {"x.getClass() == cached.xClass", "y.getClass() == cached.yClass", "length(x, cached.xClass) > 0", + "length(y, cached.yClass) > 0"}, rewriteOn = IllegalArgumentException.class) protected RAbstractIntVector intersectMaybeSorted(RAbstractIntVector x, RAbstractIntVector y, - @Cached("x.getClass()") Class<? extends RAbstractIntVector> xClass, - @Cached("y.getClass()") Class<? extends RAbstractIntVector> yClass, - @Cached("createMaybeSorted()") IntersectSortedNode intersect) { + @Cached("new(x.getClass(), y.getClass())") IntersectMaybeSortedNode cached) { // apply the type profiles: - RAbstractIntVector profiledX = xClass.cast(x); - RAbstractIntVector profiledY = yClass.cast(y); + RAbstractIntVector profiledX = cached.xClass.cast(x); + RAbstractIntVector profiledY = cached.yClass.cast(y); int xLength = profiledX.getLength(); int yLength = profiledY.getLength(); RBaseNode.reportWork(this, xLength + yLength); - int[] result = intersect.execute(profiledX, xLength, yLength, profiledY); + int[] result = cached.intersect.execute(profiledX, xLength, yLength, profiledY); return RDataFactory.createIntVector(result, profiledX.isComplete() | profiledY.isComplete()); } + public static class IntersectMaybeSortedNode extends Node { + public final Class<? extends RAbstractIntVector> xClass; + public final Class<? extends RAbstractIntVector> yClass; + @Child IntersectSortedNode intersect; + + public IntersectMaybeSortedNode(Class<? extends RAbstractIntVector> xClass, Class<? extends RAbstractIntVector> yClass) { + this.xClass = xClass; + this.yClass = yClass; + this.intersect = createMaybeSorted(); + } + } + protected static IntersectSortedNode createSorted() { return new IntersectSortedNode(true); } - @Specialization(limit = "1", guards = {"x.getClass() == xClass", "y.getClass() == yClass", "length(x, xClass) > 0", "length(y, yClass) > 0"}) + @Specialization(limit = "1", guards = {"x.getClass() == cached.xClass", "y.getClass() == cached.yClass", "length(x, cached.xClass) > 0", "length(y, cached.yClass) > 0"}) protected RAbstractIntVector intersect(RAbstractIntVector x, RAbstractIntVector y, - @Cached("x.getClass()") Class<? extends RAbstractIntVector> xClass, - @Cached("y.getClass()") Class<? extends RAbstractIntVector> yClass, - @Cached("createBinaryProfile()") ConditionProfile isXSortedProfile, - @Cached("createBinaryProfile()") ConditionProfile isYSortedProfile, - @Cached("createBinaryProfile()") ConditionProfile resultLengthMatchProfile, - @Cached("createSorted()") IntersectSortedNode intersect) { + @Cached("new(x.getClass(), y.getClass())") IntersectNode cached) { // apply the type profiles: - RAbstractIntVector profiledX = xClass.cast(x); - RAbstractIntVector profiledY = yClass.cast(y); + RAbstractIntVector profiledX = cached.xClass.cast(x); + RAbstractIntVector profiledY = cached.yClass.cast(y); int xLength = profiledX.getLength(); int yLength = profiledY.getLength(); RBaseNode.reportWork(this, xLength + yLength); int[] result; - if (isXSortedProfile.profile(isSorted(profiledX))) { + if (cached.isXSortedProfile.profile(isSorted(profiledX))) { RAbstractIntVector tempY; - if (isYSortedProfile.profile(isSorted(profiledY))) { + if (cached.isYSortedProfile.profile(isSorted(profiledY))) { tempY = profiledY; } else { int[] temp = new int[yLength]; @@ -188,7 +194,7 @@ public abstract class IntersectFastPath extends RFastPathNode { sort(temp); tempY = RDataFactory.createIntVector(temp, profiledY.isComplete()); } - result = intersect.execute(profiledX, xLength, yLength, tempY); + result = cached.intersect.execute(profiledX, xLength, yLength, tempY); } else { result = EMPTY_INT_ARRAY; int maxResultLength = Math.min(xLength, yLength); @@ -211,11 +217,26 @@ public abstract class IntersectFastPath extends RFastPathNode { result[count++] = value; } } - result = resultLengthMatchProfile.profile(count == result.length) ? result : Arrays.copyOf(result, count); + result = cached.resultLengthMatchProfile.profile(count == result.length) ? result : Arrays.copyOf(result, count); } return RDataFactory.createIntVector(result, profiledX.isComplete() | profiledY.isComplete()); } + public static class IntersectNode extends Node { + public final Class<? extends RAbstractIntVector> xClass; + public final Class<? extends RAbstractIntVector> yClass; + final ConditionProfile isXSortedProfile = ConditionProfile.createBinaryProfile(); + final ConditionProfile isYSortedProfile = ConditionProfile.createBinaryProfile(); + final ConditionProfile resultLengthMatchProfile = ConditionProfile.createBinaryProfile(); + @Child IntersectSortedNode intersect; + + public IntersectNode(Class<? extends RAbstractIntVector> xClass, Class<? extends RAbstractIntVector> yClass) { + this.xClass = xClass; + this.yClass = yClass; + this.intersect = createMaybeSorted(); + } + } + private static boolean isSorted(RAbstractIntVector vector) { int length = vector.getLength(); int lastValue = vector.getDataAt(0); -- GitLab