diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SeqFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SeqFunctions.java index 94bc6f9162937d8df8693ff5dc54b4098f30ae51..1308a933811bb1c670ed37ab6c168de4aa0acc5c 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SeqFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SeqFunctions.java @@ -29,9 +29,11 @@ import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetClassAttributeNode; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.nodes.builtin.base.SeqFunctions.SeqInt.IsIntegralNumericNode; +import com.oracle.truffle.r.nodes.builtin.base.SeqFunctionsFactory.GetIntegralNumericNodeGen; import com.oracle.truffle.r.nodes.builtin.base.SeqFunctionsFactory.IsMissingOrNumericNodeGen; +import com.oracle.truffle.r.nodes.builtin.base.SeqFunctionsFactory.IsNumericNodeGen; import com.oracle.truffle.r.nodes.builtin.base.SeqFunctionsFactory.SeqIntNodeGen; -import com.oracle.truffle.r.nodes.builtin.base.SeqFunctionsFactory.SeqIntNodeGen.GetIntegralNumericNodeGen; import com.oracle.truffle.r.nodes.builtin.base.SeqFunctionsFactory.SeqIntNodeGen.IsIntegralNumericNodeGen; import com.oracle.truffle.r.nodes.control.RLengthNode; import com.oracle.truffle.r.nodes.control.RLengthNodeGen; @@ -85,46 +87,119 @@ public class SeqFunctions { public static IsMissingOrNumericNode createIsMissingOrNumericNode() { return IsMissingOrNumericNodeGen.create(); } + + public static IsNumericNode createIsNumericNode() { + return IsNumericNodeGen.create(); + } } @TypeSystemReference(RTypesFlatLayout.class) @SuppressWarnings("unused") - public abstract static class IsMissingOrNumericNode extends Node { + public abstract static class IsNumericNode extends Node { public abstract boolean execute(Object obj); @Specialization - protected boolean isMissingOrNumericNode(RMissing obj) { + protected boolean isNumericNode(Integer obj) { return true; } @Specialization - protected boolean isMissingOrNumericNode(Integer obj) { + protected boolean isNumericNode(Double obj) { return true; } @Specialization - protected boolean isMissingOrNumericNode(Double obj) { + protected boolean isNumericNode(RAbstractIntVector obj) { return true; } @Specialization - protected boolean isMissingOrNumericNode(RAbstractIntVector obj) { + protected boolean isNumericNode(RAbstractDoubleVector obj) { return true; } + @Fallback + protected boolean isNumericNode(Object obj) { + return false; + } + } + + @TypeSystemReference(RTypesFlatLayout.class) + @SuppressWarnings("unused") + public abstract static class IsMissingOrNumericNode extends IsNumericNode { + @Specialization - protected boolean isMissingOrNumericNode(RAbstractDoubleVector obj) { + protected boolean isMissingOrNumericNode(RMissing obj) { return true; } + } + + @TypeSystemReference(RTypesFlatLayout.class) + public abstract static class GetIntegralNumericNode extends Node { + + public abstract int execute(Object obj); + + @Specialization + protected int getIntegralNumeric(Integer integer) { + return integer; + } + + @Specialization + protected int getIntegralNumeric(RAbstractIntVector intVec) { + return intVec.getDataAt(0); + } + + @Specialization + protected int getIntegralNumeric(Double d) { + return (int) (double) d; + } + + @Specialization + protected int getIntegralNumeric(RAbstractDoubleVector doubleVec) { + return (int) doubleVec.getDataAt(0); + } @Fallback - protected boolean isMissingOrNumericNode(Object obj) { - return false; + protected int getIntegralNumeric(@SuppressWarnings("unused") Object obj) { + throw RInternalError.shouldNotReachHere(); } + + } + + public static GetIntegralNumericNode createGetIntegralNumericNode() { + return GetIntegralNumericNodeGen.create(); + } + + public static IsIntegralNumericNode createIsIntegralNumericNodeNoLengthCheck() { + return IsIntegralNumericNodeGen.create(false); + } + + public static IsIntegralNumericNode createIsIntegralNumericNodeLengthCheck() { + return IsIntegralNumericNodeGen.create(true); } @TypeSystemReference(RTypesFlatLayout.class) + @ImportStatic(SeqFunctions.class) public abstract static class SeqFastPath extends FastPathAdapter { + @Specialization(guards = {"!hasClass(args, getClassAttributeNode)", "lengthSpecials(args)"}) + @SuppressWarnings("unused") + protected Object seqNoClassFromAndLength(VirtualFrame frame, RArgsValuesAndNames args, // + @Cached("createSeqIntForFastPath()") SeqInt seqInt, + @Cached("lookupSeqInt()") RFunction seqIntFunction, + @Cached("createBinaryProfile()") ConditionProfile isNumericProfile, + @Cached("createGetClassAttributeNode()") GetClassAttributeNode getClassAttributeNode, + @Cached("createIsMissingOrNumericNode()") IsMissingOrNumericNode fromCheck) { + if (isNumericProfile.profile(fromCheck.execute(args.getArgument(0)))) { + if (args.getLength() == 1) { + return seqInt.execute(frame, RMissing.instance, RMissing.instance, RMissing.instance, args.getArgument(0), RMissing.instance); + } else { + return seqInt.execute(frame, args.getArgument(0), RMissing.instance, RMissing.instance, args.getArgument(1), RMissing.instance); + } + } else { + return null; + } + } + @Specialization(guards = {"!hasClass(args, getClassAttributeNode)"}) @SuppressWarnings("unused") protected Object seqNoClassAndNumeric(VirtualFrame frame, RArgsValuesAndNames args, @@ -160,6 +235,13 @@ public class SeqFunctions { /** * The arguments are reordered if any are named, and later will be checked for missing or * numeric. + * + * N.B: the reordering has a significant performance cost, e.g. + * + * {@code seq(1L, length.out=20L)} is MUCH slower than {@code seq(1L, , , 20L)} + * + * TODO we special case the above, as it is a common idiom, but can we improve the general + * case? */ public static Object[] reorderedArguments(RArgsValuesAndNames argsIn, RFunction seqIntFunction) { RArgsValuesAndNames args = argsIn; @@ -188,6 +270,26 @@ public class SeqFunctions { } return false; } + + private static final String lengthOut = "length.out"; + + /** + * Guard that picks out the common idioms {@code seq(length.out=N)} and + * {@code seq(M, length.out=N)} N.B. assert: signature names are interned strings + */ + public boolean lengthSpecials(RArgsValuesAndNames args) { + int argsLen = args.getLength(); + if (argsLen == 1) { + String sig0 = args.getSignature().getName(0); + return sig0 != null && sig0 == lengthOut; + } else if (argsLen == 2) { + String sig0 = args.getSignature().getName(0); + String sig1 = args.getSignature().getName(1); + return sig0 == null && sig1 != null && sig1 == lengthOut; + } else { + return false; + } + } } /** @@ -268,7 +370,7 @@ public class SeqFunctions { * N.B. javac gives error "cannot find symbol" on plain "@RBuiltin". */ @TypeSystemReference(RTypesFlatLayout.class) - @ImportStatic(AsRealNodeGen.class) + @ImportStatic({AsRealNodeGen.class, SeqFunctions.class}) @com.oracle.truffle.r.runtime.builtins.RBuiltin(name = "seq.int", kind = PRIMITIVE, parameterNames = {"from", "to", "by", "length.out", "along.with", "..."}, dispatch = INTERNAL_GENERIC, genericName = "seq", behavior = PURE) @SuppressWarnings("unused") @@ -336,6 +438,15 @@ public class SeqFunctions { return RDataFactory.createIntSequence(1, 1, getLength(frame, from)); } + /** + * A length-1 REAL. Return "1:(int) from" where from is positive integral + */ + @Specialization(guards = {"fromVec.getLength() == 1", "isPositiveIntegralDouble(fromVec.getDataAt(0))"}) + protected RAbstractVector seqFromOneArgIntDouble(RAbstractDoubleVector fromVec, RMissing to, RMissing by, RMissing lengthOut, RMissing alongWith) { + int len = (int) fromVec.getDataAt(0); + return RDataFactory.createIntSequence(1, 1, len); + } + /** * A length-1 REAL. Return "1:(int) from" (N.B. from may be negative) EXCEPT * {@code seq(0.2)} is NOT the same as {@code seq(0.0)} (according to GNU R) @@ -649,49 +760,6 @@ public class SeqFunctions { } } - @TypeSystemReference(RTypesFlatLayout.class) - public abstract static class GetIntegralNumericNode extends Node { - - public abstract int execute(Object obj); - - @Specialization - protected int getIntegralNumeric(Integer integer) { - return integer; - } - - @Specialization - protected int getIntegralNumeric(RAbstractIntVector intVec) { - return intVec.getDataAt(0); - } - - @Specialization - protected int getIntegralNumeric(Double d) { - return (int) (double) d; - } - - @Specialization - protected int getIntegralNumeric(RAbstractDoubleVector doubleVec) { - return (int) doubleVec.getDataAt(0); - } - - @Fallback - protected int getIntegralNumeric(Object obj) { - throw RInternalError.shouldNotReachHere(); - } - } - - public static GetIntegralNumericNode createGetIntegralNumericNode() { - return GetIntegralNumericNodeGen.create(); - } - - public static IsIntegralNumericNode createIsIntegralNumericNodeNoLengthCheck() { - return IsIntegralNumericNodeGen.create(false); - } - - public static IsIntegralNumericNode createIsIntegralNumericNodeLengthCheck() { - return IsIntegralNumericNodeGen.create(true); - } - // common idiom @Specialization(guards = {"fromCheck.execute(fromObj)", "lengthCheck.execute(lengthOut)"}) protected RAbstractVector seqWithFromLengthIntegralNumeric(VirtualFrame frame, Object fromObj, RMissing toObj, RMissing byObj, Object lengthOut, RMissing alongWith, @@ -844,6 +912,11 @@ public class SeqFunctions { return !isMissing(obj1) || !isMissing(obj2); } + public static final boolean isPositiveIntegralDouble(double d) { + int id = (int) d; + return id == d && id > 0; + } + // Utility methods private static boolean isFinite(double v) {