diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/SetDiffFastPath.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/SetDiffFastPath.java index 0345b8e638f83efcb6bc1dc2d4e4fde86cdb3798..e343b3190bfc51150b8bfdf46b3fb07668425e2c 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/SetDiffFastPath.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/fastpaths/SetDiffFastPath.java @@ -22,6 +22,7 @@ */ package com.oracle.truffle.r.nodes.builtin.base.fastpaths; +import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.r.runtime.data.RDataFactory; @@ -31,15 +32,17 @@ import com.oracle.truffle.r.runtime.nodes.RFastPathNode; public abstract class SetDiffFastPath extends RFastPathNode { - @Specialization(guards = "x.getStride() == 1") - protected Object setdiff(RIntSequence x, RAbstractIntVector y) { + @Specialization(guards = {"x.getStride() == 1", "y.getClass() == yClass"}) + protected Object setdiff(RIntSequence x, RAbstractIntVector y, + @Cached("y.getClass()") Class<? extends RAbstractIntVector> yClass) { + RAbstractIntVector profiledY = yClass.cast(y); int xLength = x.getLength(); int xStart = x.getStart(); - int yLength = y.getLength(); + int yLength = profiledY.getLength(); boolean[] excluded = new boolean[xLength]; for (int i = 0; i < yLength; i++) { - int element = y.getDataAt(i); + int element = profiledY.getDataAt(i); int index = element - xStart; if (index >= 0 && index < xLength) { excluded[index] = true; diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/nodes/RFastPathNode.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/nodes/RFastPathNode.java index 859d2c670d46d5646566adc7fb7c65c2811465cd..443f44c0c3b15e13f696724c7bd16279f8ef7a90 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/nodes/RFastPathNode.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/nodes/RFastPathNode.java @@ -22,8 +22,11 @@ */ package com.oracle.truffle.r.runtime.nodes; +import com.oracle.truffle.api.dsl.TypeSystemReference; import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.r.runtime.data.RTypes; +@TypeSystemReference(RTypes.class) public abstract class RFastPathNode extends RBaseNode { public abstract Object execute(VirtualFrame frame, Object... args);