Skip to content
Snippets Groups Projects
Commit 8bfe469e authored by Lukas Stadler's avatar Lukas Stadler
Browse files

use native code to implement "sum" with proper precision

parent b5321493
No related branches found
No related tags found
No related merge requests found
...@@ -42,6 +42,37 @@ Java_com_oracle_truffle_r_runtime_ffi_jnr_JNI_1CallRFFI_nativeSetTempDir(JNIEnv ...@@ -42,6 +42,37 @@ Java_com_oracle_truffle_r_runtime_ffi_jnr_JNI_1CallRFFI_nativeSetTempDir(JNIEnv
setTempDir(env, tempDir); setTempDir(env, tempDir);
} }
JNIEXPORT jdouble JNICALL
Java_com_oracle_truffle_r_runtime_ffi_jnr_JNI_1CallRFFI_exactSumFunc(JNIEnv *env, jclass c, jdoubleArray values, jboolean hasNa, jboolean naRm) {
jint length = (*env)->GetArrayLength(env, values);
jdouble* contents = (jdouble*) (*env)->GetPrimitiveArrayCritical(env, values, NULL);
long double sum = 0;
int i = 0;
if (!hasNa) {
for (; i < length - 3; i+= 4) {
sum += contents[i];
sum += contents[i + 1];
sum += contents[i + 2];
sum += contents[i + 3];
}
}
for (; i < length; i++) {
jdouble value = contents[i];
if (R_IsNA(value)) {
if (!naRm) {
(*env)->ReleasePrimitiveArrayCritical(env, values, contents, JNI_ABORT);
return R_NaReal;
}
} else {
sum += value;
}
}
(*env)->ReleasePrimitiveArrayCritical(env, values, contents, JNI_ABORT);
return sum;
}
// Boilerplate methods for the actual calls // Boilerplate methods for the actual calls
typedef SEXP (*call0func)(); typedef SEXP (*call0func)();
......
...@@ -28,15 +28,22 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; ...@@ -28,15 +28,22 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.api.profiles.LoopConditionProfile;
import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.CastBuilder;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.profile.VectorLengthProfile;
import com.oracle.truffle.r.nodes.unary.UnaryArithmeticReduceNode; import com.oracle.truffle.r.nodes.unary.UnaryArithmeticReduceNode;
import com.oracle.truffle.r.nodes.unary.UnaryArithmeticReduceNode.ReduceSemantics; import com.oracle.truffle.r.nodes.unary.UnaryArithmeticReduceNode.ReduceSemantics;
import com.oracle.truffle.r.nodes.unary.UnaryArithmeticReduceNodeGen; import com.oracle.truffle.r.nodes.unary.UnaryArithmeticReduceNodeGen;
import com.oracle.truffle.r.runtime.FastROptions;
import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RArgsValuesAndNames; import com.oracle.truffle.r.runtime.data.RArgsValuesAndNames;
import com.oracle.truffle.r.runtime.data.RDoubleVector;
import com.oracle.truffle.r.runtime.ffi.RFFIFactory;
import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; import com.oracle.truffle.r.runtime.ops.BinaryArithmetic;
import com.oracle.truffle.r.runtime.ops.na.NACheck;
/** /**
* Sum has combine semantics (TBD: exactly?) and uses a reduce operation on the resulting array. * Sum has combine semantics (TBD: exactly?) and uses a reduce operation on the resulting array.
...@@ -44,6 +51,8 @@ import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; ...@@ -44,6 +51,8 @@ import com.oracle.truffle.r.runtime.ops.BinaryArithmetic;
@RBuiltin(name = "sum", kind = PRIMITIVE, parameterNames = {"...", "na.rm"}, dispatch = SUMMARY_GROUP_GENERIC, behavior = PURE) @RBuiltin(name = "sum", kind = PRIMITIVE, parameterNames = {"...", "na.rm"}, dispatch = SUMMARY_GROUP_GENERIC, behavior = PURE)
public abstract class Sum extends RBuiltinNode { public abstract class Sum extends RBuiltinNode {
protected static final boolean FULL_PRECISION = FastROptions.FullPrecisionSum.getBooleanValue();
private static final ReduceSemantics semantics = new ReduceSemantics(0, 0.0, true, null, null, true, false); private static final ReduceSemantics semantics = new ReduceSemantics(0, 0.0, true, null, null, true, false);
@Child private UnaryArithmeticReduceNode reduce = UnaryArithmeticReduceNodeGen.create(semantics, BinaryArithmetic.ADD); @Child private UnaryArithmeticReduceNode reduce = UnaryArithmeticReduceNodeGen.create(semantics, BinaryArithmetic.ADD);
...@@ -58,12 +67,46 @@ public abstract class Sum extends RBuiltinNode { ...@@ -58,12 +67,46 @@ public abstract class Sum extends RBuiltinNode {
return new Object[]{RArgsValuesAndNames.EMPTY, RRuntime.LOGICAL_FALSE}; return new Object[]{RArgsValuesAndNames.EMPTY, RRuntime.LOGICAL_FALSE};
} }
@Specialization(guards = "args.getLength() == 1") protected static boolean isRDoubleVector(Object value) {
return value instanceof RDoubleVector;
}
@Specialization(guards = {"FULL_PRECISION", "args.getLength() == 1", "isRDoubleVector(args.getArgument(0))", "naRm == cachedNaRm"})
protected double sumLengthOneRDoubleVector(RArgsValuesAndNames args, @SuppressWarnings("unused") boolean naRm,
@Cached("naRm") boolean cachedNaRm,
@Cached("create()") VectorLengthProfile lengthProfile,
@Cached("createCountingProfile()") LoopConditionProfile loopProfile,
@Cached("create()") NACheck na,
@Cached("createBinaryProfile()") ConditionProfile needsExactSumProfile) {
RDoubleVector vector = (RDoubleVector) args.getArgument(0);
int length = lengthProfile.profile(vector.getLength());
if (needsExactSumProfile.profile(length >= 3)) {
return RFFIFactory.getRFFI().getCallRFFI().exactSum(vector.getDataWithoutCopying(), !vector.isComplete(), cachedNaRm);
} else {
na.enable(vector);
loopProfile.profileCounted(length);
double sum = 0;
for (int i = 0; loopProfile.inject(i < length); i++) {
double value = vector.getDataAt(i);
if (na.check(value)) {
if (!cachedNaRm) {
return RRuntime.DOUBLE_NA;
}
} else {
sum += value;
}
}
return sum;
}
}
@Specialization(contains = "sumLengthOneRDoubleVector", guards = "args.getLength() == 1")
protected Object sumLengthOne(RArgsValuesAndNames args, boolean naRm) { protected Object sumLengthOne(RArgsValuesAndNames args, boolean naRm) {
return reduce.executeReduce(args.getArgument(0), naRm, false); return reduce.executeReduce(args.getArgument(0), naRm, false);
} }
@Specialization(contains = "sumLengthOne") @Specialization(contains = {"sumLengthOneRDoubleVector", "sumLengthOne"})
protected Object sum(RArgsValuesAndNames args, boolean naRm, // protected Object sum(RArgsValuesAndNames args, boolean naRm, //
@Cached("create()") Combine combine) { @Cached("create()") Combine combine) {
return reduce.executeReduce(combine.executeCombine(args), naRm, false); return reduce.executeReduce(combine.executeCombine(args), naRm, false);
......
...@@ -119,6 +119,8 @@ public class JNI_CallRFFI implements CallRFFI { ...@@ -119,6 +119,8 @@ public class JNI_CallRFFI implements CallRFFI {
private static native void nativeSetInteractive(boolean interactive); private static native void nativeSetInteractive(boolean interactive);
private static native double exactSumFunc(double[] values, boolean hasNa, boolean naRm);
private static native Object call(long address, Object[] args); private static native Object call(long address, Object[] args);
private static native Object call0(long address); private static native Object call0(long address);
...@@ -192,4 +194,17 @@ public class JNI_CallRFFI implements CallRFFI { ...@@ -192,4 +194,17 @@ public class JNI_CallRFFI implements CallRFFI {
} }
} }
@Override
public double exactSum(double[] values, boolean hasNa, boolean naRm) {
if (traceEnabled()) {
traceDownCall("exactSum");
}
try {
return exactSumFunc(values, hasNa, naRm);
} finally {
if (traceEnabled()) {
traceDownCallReturn("exactSum", null);
}
}
}
} }
...@@ -49,6 +49,7 @@ public enum FastROptions { ...@@ -49,6 +49,7 @@ public enum FastROptions {
PerformanceWarnings("Print FastR performance warning", false), PerformanceWarnings("Print FastR performance warning", false),
LoadBase("Load base package", true), LoadBase("Load base package", true),
PrintComplexLookups("Print a message for each non-trivial variable lookup", false), PrintComplexLookups("Print a message for each non-trivial variable lookup", false),
FullPrecisionSum("Use 128 bit arithmetic in sum builtin", false),
LoadPkgSourcesIndex("Load R package sources index", true), LoadPkgSourcesIndex("Load R package sources index", true),
InvisibleArgs("Argument writes do not trigger state transitions", true), InvisibleArgs("Argument writes do not trigger state transitions", true),
RefCountIncrementOnly("Disable reference count decrements for experimental state transition implementation", false), RefCountIncrementOnly("Disable reference count decrements for experimental state transition implementation", false),
......
...@@ -52,4 +52,6 @@ public interface CallRFFI { ...@@ -52,4 +52,6 @@ public interface CallRFFI {
* Sets the {@code R_Interactive} FFI variable. Similar rationale to {#link setTmpDir}. * Sets the {@code R_Interactive} FFI variable. Similar rationale to {#link setTmpDir}.
*/ */
void setInteractive(boolean interactive); void setInteractive(boolean interactive);
double exactSum(double[] values, boolean hasNa, boolean naRm);
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment