diff --git a/com.oracle.truffle.r.native/fficall/src/jni/rfficall.c b/com.oracle.truffle.r.native/fficall/src/jni/rfficall.c index 9b2a545370781553d0f7157366b688ea3dae1496..78d3d632e39a1eaadb630696fb32713075387609 100644 --- a/com.oracle.truffle.r.native/fficall/src/jni/rfficall.c +++ b/com.oracle.truffle.r.native/fficall/src/jni/rfficall.c @@ -42,6 +42,37 @@ Java_com_oracle_truffle_r_runtime_ffi_jnr_JNI_1CallRFFI_nativeSetTempDir(JNIEnv setTempDir(env, tempDir); } +JNIEXPORT jdouble JNICALL +Java_com_oracle_truffle_r_runtime_ffi_jnr_JNI_1CallRFFI_exactSumFunc(JNIEnv *env, jclass c, jdoubleArray values, jboolean hasNa, jboolean naRm) { + jint length = (*env)->GetArrayLength(env, values); + jdouble* contents = (jdouble*) (*env)->GetPrimitiveArrayCritical(env, values, NULL); + + long double sum = 0; + int i = 0; + if (!hasNa) { + for (; i < length - 3; i+= 4) { + sum += contents[i]; + sum += contents[i + 1]; + sum += contents[i + 2]; + sum += contents[i + 3]; + } + } + for (; i < length; i++) { + jdouble value = contents[i]; + if (R_IsNA(value)) { + if (!naRm) { + (*env)->ReleasePrimitiveArrayCritical(env, values, contents, JNI_ABORT); + return R_NaReal; + } + } else { + sum += value; + } + } + + (*env)->ReleasePrimitiveArrayCritical(env, values, contents, JNI_ABORT); + return sum; +} + // Boilerplate methods for the actual calls typedef SEXP (*call0func)(); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sum.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sum.java index c2ad959b6031e5fc4b0781633a844dd8cb96f458..17c735f5162aa5988c9bbfa271e74e32df72037a 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sum.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sum.java @@ -28,15 +28,22 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.api.profiles.LoopConditionProfile; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.nodes.profile.VectorLengthProfile; import com.oracle.truffle.r.nodes.unary.UnaryArithmeticReduceNode; import com.oracle.truffle.r.nodes.unary.UnaryArithmeticReduceNode.ReduceSemantics; import com.oracle.truffle.r.nodes.unary.UnaryArithmeticReduceNodeGen; +import com.oracle.truffle.r.runtime.FastROptions; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RArgsValuesAndNames; +import com.oracle.truffle.r.runtime.data.RDoubleVector; +import com.oracle.truffle.r.runtime.ffi.RFFIFactory; import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; +import com.oracle.truffle.r.runtime.ops.na.NACheck; /** * Sum has combine semantics (TBD: exactly?) and uses a reduce operation on the resulting array. @@ -44,6 +51,8 @@ import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; @RBuiltin(name = "sum", kind = PRIMITIVE, parameterNames = {"...", "na.rm"}, dispatch = SUMMARY_GROUP_GENERIC, behavior = PURE) public abstract class Sum extends RBuiltinNode { + protected static final boolean FULL_PRECISION = FastROptions.FullPrecisionSum.getBooleanValue(); + private static final ReduceSemantics semantics = new ReduceSemantics(0, 0.0, true, null, null, true, false); @Child private UnaryArithmeticReduceNode reduce = UnaryArithmeticReduceNodeGen.create(semantics, BinaryArithmetic.ADD); @@ -58,12 +67,46 @@ public abstract class Sum extends RBuiltinNode { return new Object[]{RArgsValuesAndNames.EMPTY, RRuntime.LOGICAL_FALSE}; } - @Specialization(guards = "args.getLength() == 1") + protected static boolean isRDoubleVector(Object value) { + return value instanceof RDoubleVector; + } + + @Specialization(guards = {"FULL_PRECISION", "args.getLength() == 1", "isRDoubleVector(args.getArgument(0))", "naRm == cachedNaRm"}) + protected double sumLengthOneRDoubleVector(RArgsValuesAndNames args, @SuppressWarnings("unused") boolean naRm, + @Cached("naRm") boolean cachedNaRm, + @Cached("create()") VectorLengthProfile lengthProfile, + @Cached("createCountingProfile()") LoopConditionProfile loopProfile, + @Cached("create()") NACheck na, + @Cached("createBinaryProfile()") ConditionProfile needsExactSumProfile) { + RDoubleVector vector = (RDoubleVector) args.getArgument(0); + int length = lengthProfile.profile(vector.getLength()); + + if (needsExactSumProfile.profile(length >= 3)) { + return RFFIFactory.getRFFI().getCallRFFI().exactSum(vector.getDataWithoutCopying(), !vector.isComplete(), cachedNaRm); + } else { + na.enable(vector); + loopProfile.profileCounted(length); + double sum = 0; + for (int i = 0; loopProfile.inject(i < length); i++) { + double value = vector.getDataAt(i); + if (na.check(value)) { + if (!cachedNaRm) { + return RRuntime.DOUBLE_NA; + } + } else { + sum += value; + } + } + return sum; + } + } + + @Specialization(contains = "sumLengthOneRDoubleVector", guards = "args.getLength() == 1") protected Object sumLengthOne(RArgsValuesAndNames args, boolean naRm) { return reduce.executeReduce(args.getArgument(0), naRm, false); } - @Specialization(contains = "sumLengthOne") + @Specialization(contains = {"sumLengthOneRDoubleVector", "sumLengthOne"}) protected Object sum(RArgsValuesAndNames args, boolean naRm, // @Cached("create()") Combine combine) { return reduce.executeReduce(combine.executeCombine(args), naRm, false); diff --git a/com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNI_CallRFFI.java b/com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNI_CallRFFI.java index 70c2779c9081b931c2910f712446abeaa6de3022..3f4525fdc087e9d4c4d2292bde8ff38b32951ddf 100644 --- a/com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNI_CallRFFI.java +++ b/com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNI_CallRFFI.java @@ -119,6 +119,8 @@ public class JNI_CallRFFI implements CallRFFI { private static native void nativeSetInteractive(boolean interactive); + private static native double exactSumFunc(double[] values, boolean hasNa, boolean naRm); + private static native Object call(long address, Object[] args); private static native Object call0(long address); @@ -192,4 +194,17 @@ public class JNI_CallRFFI implements CallRFFI { } } + @Override + public double exactSum(double[] values, boolean hasNa, boolean naRm) { + if (traceEnabled()) { + traceDownCall("exactSum"); + } + try { + return exactSumFunc(values, hasNa, naRm); + } finally { + if (traceEnabled()) { + traceDownCallReturn("exactSum", null); + } + } + } } diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/FastROptions.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/FastROptions.java index 7fc2ace02c4adf4eae73cc4b2d62686482d80f83..3927a5a8d9af182d188ea56c55105c187d531ec6 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/FastROptions.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/FastROptions.java @@ -49,6 +49,7 @@ public enum FastROptions { PerformanceWarnings("Print FastR performance warning", false), LoadBase("Load base package", true), PrintComplexLookups("Print a message for each non-trivial variable lookup", false), + FullPrecisionSum("Use 128 bit arithmetic in sum builtin", false), LoadPkgSourcesIndex("Load R package sources index", true), InvisibleArgs("Argument writes do not trigger state transitions", true), RefCountIncrementOnly("Disable reference count decrements for experimental state transition implementation", false), diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/CallRFFI.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/CallRFFI.java index 8770a46110928aa432433cce70c077fa435e68df..96d6bc3ad112816ecc3013c3b7ff26a52b2fce59 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/CallRFFI.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/CallRFFI.java @@ -52,4 +52,6 @@ public interface CallRFFI { * Sets the {@code R_Interactive} FFI variable. Similar rationale to {#link setTmpDir}. */ void setInteractive(boolean interactive); + + double exactSum(double[] values, boolean hasNa, boolean naRm); }