Skip to content
Snippets Groups Projects
Commit ce2df48e authored by Mick Jordan's avatar Mick Jordan
Browse files

seq: add additional fast-path specializations

parent ae27a0c6
No related branches found
No related tags found
No related merge requests found
......@@ -29,9 +29,11 @@ import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetClassAttributeNode;
import com.oracle.truffle.r.nodes.builtin.CastBuilder;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.base.SeqFunctions.SeqInt.IsIntegralNumericNode;
import com.oracle.truffle.r.nodes.builtin.base.SeqFunctionsFactory.GetIntegralNumericNodeGen;
import com.oracle.truffle.r.nodes.builtin.base.SeqFunctionsFactory.IsMissingOrNumericNodeGen;
import com.oracle.truffle.r.nodes.builtin.base.SeqFunctionsFactory.IsNumericNodeGen;
import com.oracle.truffle.r.nodes.builtin.base.SeqFunctionsFactory.SeqIntNodeGen;
import com.oracle.truffle.r.nodes.builtin.base.SeqFunctionsFactory.SeqIntNodeGen.GetIntegralNumericNodeGen;
import com.oracle.truffle.r.nodes.builtin.base.SeqFunctionsFactory.SeqIntNodeGen.IsIntegralNumericNodeGen;
import com.oracle.truffle.r.nodes.control.RLengthNode;
import com.oracle.truffle.r.nodes.control.RLengthNodeGen;
......@@ -85,46 +87,119 @@ public class SeqFunctions {
public static IsMissingOrNumericNode createIsMissingOrNumericNode() {
return IsMissingOrNumericNodeGen.create();
}
public static IsNumericNode createIsNumericNode() {
return IsNumericNodeGen.create();
}
}
@TypeSystemReference(RTypesFlatLayout.class)
@SuppressWarnings("unused")
public abstract static class IsMissingOrNumericNode extends Node {
public abstract static class IsNumericNode extends Node {
public abstract boolean execute(Object obj);
@Specialization
protected boolean isMissingOrNumericNode(RMissing obj) {
protected boolean isNumericNode(Integer obj) {
return true;
}
@Specialization
protected boolean isMissingOrNumericNode(Integer obj) {
protected boolean isNumericNode(Double obj) {
return true;
}
@Specialization
protected boolean isMissingOrNumericNode(Double obj) {
protected boolean isNumericNode(RAbstractIntVector obj) {
return true;
}
@Specialization
protected boolean isMissingOrNumericNode(RAbstractIntVector obj) {
protected boolean isNumericNode(RAbstractDoubleVector obj) {
return true;
}
@Fallback
protected boolean isNumericNode(Object obj) {
return false;
}
}
@TypeSystemReference(RTypesFlatLayout.class)
@SuppressWarnings("unused")
public abstract static class IsMissingOrNumericNode extends IsNumericNode {
@Specialization
protected boolean isMissingOrNumericNode(RAbstractDoubleVector obj) {
protected boolean isMissingOrNumericNode(RMissing obj) {
return true;
}
}
@TypeSystemReference(RTypesFlatLayout.class)
public abstract static class GetIntegralNumericNode extends Node {
public abstract int execute(Object obj);
@Specialization
protected int getIntegralNumeric(Integer integer) {
return integer;
}
@Specialization
protected int getIntegralNumeric(RAbstractIntVector intVec) {
return intVec.getDataAt(0);
}
@Specialization
protected int getIntegralNumeric(Double d) {
return (int) (double) d;
}
@Specialization
protected int getIntegralNumeric(RAbstractDoubleVector doubleVec) {
return (int) doubleVec.getDataAt(0);
}
@Fallback
protected boolean isMissingOrNumericNode(Object obj) {
return false;
protected int getIntegralNumeric(@SuppressWarnings("unused") Object obj) {
throw RInternalError.shouldNotReachHere();
}
}
public static GetIntegralNumericNode createGetIntegralNumericNode() {
return GetIntegralNumericNodeGen.create();
}
public static IsIntegralNumericNode createIsIntegralNumericNodeNoLengthCheck() {
return IsIntegralNumericNodeGen.create(false);
}
public static IsIntegralNumericNode createIsIntegralNumericNodeLengthCheck() {
return IsIntegralNumericNodeGen.create(true);
}
@TypeSystemReference(RTypesFlatLayout.class)
@ImportStatic(SeqFunctions.class)
public abstract static class SeqFastPath extends FastPathAdapter {
@Specialization(guards = {"!hasClass(args, getClassAttributeNode)", "lengthSpecials(args)"})
@SuppressWarnings("unused")
protected Object seqNoClassFromAndLength(VirtualFrame frame, RArgsValuesAndNames args, //
@Cached("createSeqIntForFastPath()") SeqInt seqInt,
@Cached("lookupSeqInt()") RFunction seqIntFunction,
@Cached("createBinaryProfile()") ConditionProfile isNumericProfile,
@Cached("createGetClassAttributeNode()") GetClassAttributeNode getClassAttributeNode,
@Cached("createIsMissingOrNumericNode()") IsMissingOrNumericNode fromCheck) {
if (isNumericProfile.profile(fromCheck.execute(args.getArgument(0)))) {
if (args.getLength() == 1) {
return seqInt.execute(frame, RMissing.instance, RMissing.instance, RMissing.instance, args.getArgument(0), RMissing.instance);
} else {
return seqInt.execute(frame, args.getArgument(0), RMissing.instance, RMissing.instance, args.getArgument(1), RMissing.instance);
}
} else {
return null;
}
}
@Specialization(guards = {"!hasClass(args, getClassAttributeNode)"})
@SuppressWarnings("unused")
protected Object seqNoClassAndNumeric(VirtualFrame frame, RArgsValuesAndNames args,
......@@ -160,6 +235,13 @@ public class SeqFunctions {
/**
* The arguments are reordered if any are named, and later will be checked for missing or
* numeric.
*
* N.B: the reordering has a significant performance cost, e.g.
*
* {@code seq(1L, length.out=20L)} is MUCH slower than {@code seq(1L, , , 20L)}
*
* TODO we special case the above, as it is a common idiom, but can we improve the general
* case?
*/
public static Object[] reorderedArguments(RArgsValuesAndNames argsIn, RFunction seqIntFunction) {
RArgsValuesAndNames args = argsIn;
......@@ -188,6 +270,26 @@ public class SeqFunctions {
}
return false;
}
private static final String lengthOut = "length.out";
/**
* Guard that picks out the common idioms {@code seq(length.out=N)} and
* {@code seq(M, length.out=N)} N.B. assert: signature names are interned strings
*/
public boolean lengthSpecials(RArgsValuesAndNames args) {
int argsLen = args.getLength();
if (argsLen == 1) {
String sig0 = args.getSignature().getName(0);
return sig0 != null && sig0 == lengthOut;
} else if (argsLen == 2) {
String sig0 = args.getSignature().getName(0);
String sig1 = args.getSignature().getName(1);
return sig0 == null && sig1 != null && sig1 == lengthOut;
} else {
return false;
}
}
}
/**
......@@ -268,7 +370,7 @@ public class SeqFunctions {
* N.B. javac gives error "cannot find symbol" on plain "@RBuiltin".
*/
@TypeSystemReference(RTypesFlatLayout.class)
@ImportStatic(AsRealNodeGen.class)
@ImportStatic({AsRealNodeGen.class, SeqFunctions.class})
@com.oracle.truffle.r.runtime.builtins.RBuiltin(name = "seq.int", kind = PRIMITIVE, parameterNames = {"from", "to", "by", "length.out", "along.with",
"..."}, dispatch = INTERNAL_GENERIC, genericName = "seq", behavior = PURE)
@SuppressWarnings("unused")
......@@ -336,6 +438,15 @@ public class SeqFunctions {
return RDataFactory.createIntSequence(1, 1, getLength(frame, from));
}
/**
* A length-1 REAL. Return "1:(int) from" where from is positive integral
*/
@Specialization(guards = {"fromVec.getLength() == 1", "isPositiveIntegralDouble(fromVec.getDataAt(0))"})
protected RAbstractVector seqFromOneArgIntDouble(RAbstractDoubleVector fromVec, RMissing to, RMissing by, RMissing lengthOut, RMissing alongWith) {
int len = (int) fromVec.getDataAt(0);
return RDataFactory.createIntSequence(1, 1, len);
}
/**
* A length-1 REAL. Return "1:(int) from" (N.B. from may be negative) EXCEPT
* {@code seq(0.2)} is NOT the same as {@code seq(0.0)} (according to GNU R)
......@@ -649,49 +760,6 @@ public class SeqFunctions {
}
}
@TypeSystemReference(RTypesFlatLayout.class)
public abstract static class GetIntegralNumericNode extends Node {
public abstract int execute(Object obj);
@Specialization
protected int getIntegralNumeric(Integer integer) {
return integer;
}
@Specialization
protected int getIntegralNumeric(RAbstractIntVector intVec) {
return intVec.getDataAt(0);
}
@Specialization
protected int getIntegralNumeric(Double d) {
return (int) (double) d;
}
@Specialization
protected int getIntegralNumeric(RAbstractDoubleVector doubleVec) {
return (int) doubleVec.getDataAt(0);
}
@Fallback
protected int getIntegralNumeric(Object obj) {
throw RInternalError.shouldNotReachHere();
}
}
public static GetIntegralNumericNode createGetIntegralNumericNode() {
return GetIntegralNumericNodeGen.create();
}
public static IsIntegralNumericNode createIsIntegralNumericNodeNoLengthCheck() {
return IsIntegralNumericNodeGen.create(false);
}
public static IsIntegralNumericNode createIsIntegralNumericNodeLengthCheck() {
return IsIntegralNumericNodeGen.create(true);
}
// common idiom
@Specialization(guards = {"fromCheck.execute(fromObj)", "lengthCheck.execute(lengthOut)"})
protected RAbstractVector seqWithFromLengthIntegralNumeric(VirtualFrame frame, Object fromObj, RMissing toObj, RMissing byObj, Object lengthOut, RMissing alongWith,
......@@ -844,6 +912,11 @@ public class SeqFunctions {
return !isMissing(obj1) || !isMissing(obj2);
}
public static final boolean isPositiveIntegralDouble(double d) {
int id = (int) d;
return id == d && id > 0;
}
// Utility methods
private static boolean isFinite(double v) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment