diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/PPSum.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/PPSum.java index 6fd2053b204f3a7d2ec7b1df29a72e0ae1c5c74c..fc90f5fc2cc3c3903c806c11ef00bfb6e5629569 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/PPSum.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/PPSum.java @@ -22,15 +22,59 @@ */ package com.oracle.truffle.r.library.stats; +import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.r.library.stats.PPSumFactory.IntgrtVecNodeGen; +import com.oracle.truffle.r.library.stats.PPSumFactory.PPSumExternalNodeGen; import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode; +import com.oracle.truffle.r.runtime.data.RDataFactory; +import com.oracle.truffle.r.runtime.data.RDataFactory.VectorFactory; import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; public abstract class PPSum { + public abstract static class PPSumExternal extends RExternalBuiltinNode.Arg2 { + static { + Casts casts = new Casts(PPSumExternal.class); + casts.arg(0).asDoubleVector(); + casts.arg(1).asIntegerVector().findFirst(); + } + + @Specialization(guards = "uAccess.supports(u)") + protected RDoubleVector doPPSum(RAbstractDoubleVector u, int sl, + @Cached("create()") VectorFactory factory, + @Cached("u.access()") VectorAccess uAccess) { + + RandomIterator uIter = uAccess.randomAccess(u); + int n = uAccess.getLength(uIter); + double tmp1 = 0.0; + for (int i = 1; i <= sl; i++) { + double tmp2 = 0.0; + for (int j = i; j < n; j++) { + tmp2 += uAccess.getDouble(uIter, j) * uAccess.getDouble(uIter, j - i); + } + tmp2 *= 1.0 - i / (sl + 1.0); + tmp1 += tmp2; + } + return factory.createDoubleVectorFromScalar(2.0 * tmp1 / n); + } + + @Specialization(replaces = "doPPSum") + protected RDoubleVector doPPSumGeneric(RAbstractDoubleVector u, int sl, + @Cached("create()") VectorFactory factory) { + return doPPSum(u, sl, factory, u.slowPathAccess()); + } + + public static PPSumExternal create() { + return PPSumExternalNodeGen.create(); + } + } + /** * Implementation of function 'intgrt_vec'. */ diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/CallAndExternalFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/CallAndExternalFunctions.java index f88668a6f8346aef115289138e89fdc59b3abccc..4c85b9c1f19661da99173f55948e9751b1052939 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/CallAndExternalFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/foreign/CallAndExternalFunctions.java @@ -46,6 +46,7 @@ import com.oracle.truffle.r.library.stats.CutreeNodeGen; import com.oracle.truffle.r.library.stats.DoubleCentreNodeGen; import com.oracle.truffle.r.library.stats.Influence; import com.oracle.truffle.r.library.stats.PPSum; +import com.oracle.truffle.r.library.stats.PPSum.PPSumExternal; import com.oracle.truffle.r.library.stats.RMultinomNode; import com.oracle.truffle.r.library.stats.RandFunctionsNodes; import com.oracle.truffle.r.library.stats.RandFunctionsNodes.RandFunction1Node; @@ -555,6 +556,8 @@ public class CallAndExternalFunctions { // routines from core return new UnimplementedExternal(name); + case "pp_sum": + return PPSumExternal.create(); case "intgrt_vec": return PPSum.IntgrtVecNode.create(); diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index 31b07e1765a1bc1bb3d96297b1d9543bcb95bb48..478566456a1036954333be515326edf6188d38f0 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -159634,6 +159634,10 @@ $wt.res -5.55671071 -0.24145935 -3.28230239 2.07679287 +##com.oracle.truffle.r.test.library.stats.TestExternal_ppsum.testPPSum# +#.Call(stats:::C_pp_sum, c(1,2,3,4,5), 3) +[1] 18.6 + ##com.oracle.truffle.r.test.library.stats.TestExternal_qgamma.testQgamma# #qgamma(0, 1) [1] 0 diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestExternal_ppsum.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestExternal_ppsum.java new file mode 100644 index 0000000000000000000000000000000000000000..2aa408e487f532e268a475cc0a6986a0c82346b1 --- /dev/null +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/stats/TestExternal_ppsum.java @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package com.oracle.truffle.r.test.library.stats; + +import org.junit.Test; + +import com.oracle.truffle.r.test.TestBase; + +public class TestExternal_ppsum extends TestBase { + @Test + public void testPPSum() { + assertEval(".Call(stats:::C_pp_sum, c(1,2,3,4,5), 3)"); + } +}