From 8bfe469ecdc2eea0a94e26ba95c490defc1b801f Mon Sep 17 00:00:00 2001
From: Lukas Stadler <lukas.stadler@oracle.com>
Date: Tue, 20 Sep 2016 13:04:36 +0200
Subject: [PATCH] use native code to implement "sum" with proper precision

---
 .../fficall/src/jni/rfficall.c                | 31 ++++++++++++
 .../truffle/r/nodes/builtin/base/Sum.java     | 47 ++++++++++++++++++-
 .../r/runtime/ffi/jnr/JNI_CallRFFI.java       | 15 ++++++
 .../truffle/r/runtime/FastROptions.java       |  1 +
 .../truffle/r/runtime/ffi/CallRFFI.java       |  2 +
 5 files changed, 94 insertions(+), 2 deletions(-)

diff --git a/com.oracle.truffle.r.native/fficall/src/jni/rfficall.c b/com.oracle.truffle.r.native/fficall/src/jni/rfficall.c
index 9b2a545370..78d3d632e3 100644
--- a/com.oracle.truffle.r.native/fficall/src/jni/rfficall.c
+++ b/com.oracle.truffle.r.native/fficall/src/jni/rfficall.c
@@ -42,6 +42,37 @@ Java_com_oracle_truffle_r_runtime_ffi_jnr_JNI_1CallRFFI_nativeSetTempDir(JNIEnv
 	setTempDir(env, tempDir);
 }
 
+JNIEXPORT jdouble JNICALL
+Java_com_oracle_truffle_r_runtime_ffi_jnr_JNI_1CallRFFI_exactSumFunc(JNIEnv *env, jclass c, jdoubleArray values, jboolean hasNa, jboolean naRm) {
+	jint length = (*env)->GetArrayLength(env, values);
+	jdouble* contents = (jdouble*) (*env)->GetPrimitiveArrayCritical(env, values, NULL);
+
+	long double sum = 0;
+	int i = 0;
+	if (!hasNa) {
+		for (; i < length - 3; i+= 4) {
+			sum += contents[i];
+			sum += contents[i + 1];
+			sum += contents[i + 2];
+			sum += contents[i + 3];
+		}
+	}
+	for (; i < length; i++) {
+		jdouble value = contents[i];
+		if (R_IsNA(value)) {
+			if (!naRm) {
+				(*env)->ReleasePrimitiveArrayCritical(env, values, contents, JNI_ABORT);
+				return R_NaReal;
+			}
+		} else {
+			sum += value;
+		}
+	}
+
+	(*env)->ReleasePrimitiveArrayCritical(env, values, contents, JNI_ABORT);
+	return sum;
+}
+
 // Boilerplate methods for the actual calls
 
 typedef SEXP (*call0func)();
diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sum.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sum.java
index c2ad959b60..17c735f516 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sum.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sum.java
@@ -28,15 +28,22 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
 
 import com.oracle.truffle.api.dsl.Cached;
 import com.oracle.truffle.api.dsl.Specialization;
+import com.oracle.truffle.api.profiles.ConditionProfile;
+import com.oracle.truffle.api.profiles.LoopConditionProfile;
 import com.oracle.truffle.r.nodes.builtin.CastBuilder;
 import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
+import com.oracle.truffle.r.nodes.profile.VectorLengthProfile;
 import com.oracle.truffle.r.nodes.unary.UnaryArithmeticReduceNode;
 import com.oracle.truffle.r.nodes.unary.UnaryArithmeticReduceNode.ReduceSemantics;
 import com.oracle.truffle.r.nodes.unary.UnaryArithmeticReduceNodeGen;
+import com.oracle.truffle.r.runtime.FastROptions;
 import com.oracle.truffle.r.runtime.RRuntime;
 import com.oracle.truffle.r.runtime.builtins.RBuiltin;
 import com.oracle.truffle.r.runtime.data.RArgsValuesAndNames;
+import com.oracle.truffle.r.runtime.data.RDoubleVector;
+import com.oracle.truffle.r.runtime.ffi.RFFIFactory;
 import com.oracle.truffle.r.runtime.ops.BinaryArithmetic;
+import com.oracle.truffle.r.runtime.ops.na.NACheck;
 
 /**
  * Sum has combine semantics (TBD: exactly?) and uses a reduce operation on the resulting array.
@@ -44,6 +51,8 @@ import com.oracle.truffle.r.runtime.ops.BinaryArithmetic;
 @RBuiltin(name = "sum", kind = PRIMITIVE, parameterNames = {"...", "na.rm"}, dispatch = SUMMARY_GROUP_GENERIC, behavior = PURE)
 public abstract class Sum extends RBuiltinNode {
 
+    protected static final boolean FULL_PRECISION = FastROptions.FullPrecisionSum.getBooleanValue();
+
     private static final ReduceSemantics semantics = new ReduceSemantics(0, 0.0, true, null, null, true, false);
 
     @Child private UnaryArithmeticReduceNode reduce = UnaryArithmeticReduceNodeGen.create(semantics, BinaryArithmetic.ADD);
@@ -58,12 +67,46 @@ public abstract class Sum extends RBuiltinNode {
         return new Object[]{RArgsValuesAndNames.EMPTY, RRuntime.LOGICAL_FALSE};
     }
 
-    @Specialization(guards = "args.getLength() == 1")
+    protected static boolean isRDoubleVector(Object value) {
+        return value instanceof RDoubleVector;
+    }
+
+    @Specialization(guards = {"FULL_PRECISION", "args.getLength() == 1", "isRDoubleVector(args.getArgument(0))", "naRm == cachedNaRm"})
+    protected double sumLengthOneRDoubleVector(RArgsValuesAndNames args, @SuppressWarnings("unused") boolean naRm,
+                    @Cached("naRm") boolean cachedNaRm,
+                    @Cached("create()") VectorLengthProfile lengthProfile,
+                    @Cached("createCountingProfile()") LoopConditionProfile loopProfile,
+                    @Cached("create()") NACheck na,
+                    @Cached("createBinaryProfile()") ConditionProfile needsExactSumProfile) {
+        RDoubleVector vector = (RDoubleVector) args.getArgument(0);
+        int length = lengthProfile.profile(vector.getLength());
+
+        if (needsExactSumProfile.profile(length >= 3)) {
+            return RFFIFactory.getRFFI().getCallRFFI().exactSum(vector.getDataWithoutCopying(), !vector.isComplete(), cachedNaRm);
+        } else {
+            na.enable(vector);
+            loopProfile.profileCounted(length);
+            double sum = 0;
+            for (int i = 0; loopProfile.inject(i < length); i++) {
+                double value = vector.getDataAt(i);
+                if (na.check(value)) {
+                    if (!cachedNaRm) {
+                        return RRuntime.DOUBLE_NA;
+                    }
+                } else {
+                    sum += value;
+                }
+            }
+            return sum;
+        }
+    }
+
+    @Specialization(contains = "sumLengthOneRDoubleVector", guards = "args.getLength() == 1")
     protected Object sumLengthOne(RArgsValuesAndNames args, boolean naRm) {
         return reduce.executeReduce(args.getArgument(0), naRm, false);
     }
 
-    @Specialization(contains = "sumLengthOne")
+    @Specialization(contains = {"sumLengthOneRDoubleVector", "sumLengthOne"})
     protected Object sum(RArgsValuesAndNames args, boolean naRm, //
                     @Cached("create()") Combine combine) {
         return reduce.executeReduce(combine.executeCombine(args), naRm, false);
diff --git a/com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNI_CallRFFI.java b/com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNI_CallRFFI.java
index 70c2779c90..3f4525fdc0 100644
--- a/com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNI_CallRFFI.java
+++ b/com.oracle.truffle.r.runtime.ffi/src/com/oracle/truffle/r/runtime/ffi/jnr/JNI_CallRFFI.java
@@ -119,6 +119,8 @@ public class JNI_CallRFFI implements CallRFFI {
 
     private static native void nativeSetInteractive(boolean interactive);
 
+    private static native double exactSumFunc(double[] values, boolean hasNa, boolean naRm);
+
     private static native Object call(long address, Object[] args);
 
     private static native Object call0(long address);
@@ -192,4 +194,17 @@ public class JNI_CallRFFI implements CallRFFI {
         }
     }
 
+    @Override
+    public double exactSum(double[] values, boolean hasNa, boolean naRm) {
+        if (traceEnabled()) {
+            traceDownCall("exactSum");
+        }
+        try {
+            return exactSumFunc(values, hasNa, naRm);
+        } finally {
+            if (traceEnabled()) {
+                traceDownCallReturn("exactSum", null);
+            }
+        }
+    }
 }
diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/FastROptions.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/FastROptions.java
index 7fc2ace02c..3927a5a8d9 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/FastROptions.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/FastROptions.java
@@ -49,6 +49,7 @@ public enum FastROptions {
     PerformanceWarnings("Print FastR performance warning", false),
     LoadBase("Load base package", true),
     PrintComplexLookups("Print a message for each non-trivial variable lookup", false),
+    FullPrecisionSum("Use 128 bit arithmetic in sum builtin", false),
     LoadPkgSourcesIndex("Load R package sources index", true),
     InvisibleArgs("Argument writes do not trigger state transitions", true),
     RefCountIncrementOnly("Disable reference count decrements for experimental state transition implementation", false),
diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/CallRFFI.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/CallRFFI.java
index 8770a46110..96d6bc3ad1 100644
--- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/CallRFFI.java
+++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ffi/CallRFFI.java
@@ -52,4 +52,6 @@ public interface CallRFFI {
      * Sets the {@code R_Interactive} FFI variable. Similar rationale to {#link setTmpDir}.
      */
     void setInteractive(boolean interactive);
+
+    double exactSum(double[] values, boolean hasNa, boolean naRm);
 }
-- 
GitLab