diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Mean.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Mean.java index 53a91f171580ab9796ecd04558e257dbee5a5ac3..a4f4732f9fe0504cdf94bfe4348ce04350b2ea76 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Mean.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Mean.java @@ -26,81 +26,79 @@ import static com.oracle.truffle.r.runtime.RDispatch.INTERNAL_GENERIC; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE_SUMMARY; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; +import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.ImportStatic; import com.oracle.truffle.api.dsl.Specialization; -import com.oracle.truffle.api.profiles.BranchProfile; +import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.runtime.RType; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RComplex; -import com.oracle.truffle.r.runtime.data.RDataFactory; -import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; -import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; +import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; +@ImportStatic(RType.class) @RBuiltin(name = "mean", kind = INTERNAL, parameterNames = {"x"}, dispatch = INTERNAL_GENERIC, behavior = PURE_SUMMARY) public abstract class Mean extends RBuiltinNode.Arg1 { - private final BranchProfile emptyProfile = BranchProfile.create(); - - @Child private BinaryArithmetic add = BinaryArithmetic.ADD.createOperation(); - @Child private BinaryArithmetic div = BinaryArithmetic.DIV.createOperation(); - static { Casts.noCasts(Mean.class); } - @Specialization - protected double mean(RAbstractDoubleVector x) { - if (x.getLength() == 0) { - emptyProfile.enter(); - return Double.NaN; + @Specialization(guards = {"access.supports(x)", "access.getType() != Complex"}) + protected double meanDoubleCached(RAbstractVector x, + @Cached("x.access()") VectorAccess access, + @Cached("createBinaryProfile()") ConditionProfile emptyProfile) { + try (SequentialIterator iter = access.access(x)) { + if (emptyProfile.profile(!access.next(iter))) { + return Double.NaN; + } + double sum = 0; + do { + double value = access.getDouble(iter); + if (access.na.checkNAorNaN(value)) { + return value; + } + sum += value; + } while (access.next(iter)); + return sum / access.getLength(iter); } - double sum = x.getDataAt(0); - for (int k = 1; k < x.getLength(); k++) { - sum = add.op(sum, x.getDataAt(k)); - } - return div.op(sum, x.getLength()); } - @Specialization - protected double mean(RAbstractIntVector x) { - if (x.getLength() == 0) { - emptyProfile.enter(); - return Double.NaN; - } - double sum = x.getDataAt(0); - for (int k = 1; k < x.getLength(); k++) { - sum = add.op(sum, x.getDataAt(k)); - } - return div.op(sum, x.getLength()); + @Specialization(replaces = "meanDoubleCached", guards = "x.getRType() != Complex") + protected double meanDoubleGeneric(RAbstractVector x, + @Cached("createBinaryProfile()") ConditionProfile emptyProfile) { + return meanDoubleCached(x, x.slowPathAccess(), emptyProfile); } - @Specialization - protected double mean(RAbstractLogicalVector x) { - if (x.getLength() == 0) { - emptyProfile.enter(); - return Double.NaN; - } - double sum = x.getDataAt(0); - for (int k = 1; k < x.getLength(); k++) { - sum = add.op(sum, x.getDataAt(k)); + @Specialization(guards = {"access.supports(x)", "access.getType() == Complex"}) + protected RComplex meanComplexCached(RAbstractVector x, + @Cached("x.access()") VectorAccess access, + @Cached("createBinaryProfile()") ConditionProfile emptyProfile) { + try (SequentialIterator iter = access.access(x)) { + if (emptyProfile.profile(!access.next(iter))) { + return RComplex.valueOf(Double.NaN, Double.NaN); + } + double sumR = 0; + double sumI = 0; + do { + double valueR = access.getComplexR(iter); + double valueI = access.getComplexI(iter); + if (access.na.check(valueR, valueI)) { + return RComplex.valueOf(valueR, valueI); + } + sumR += valueR; + sumI += valueI; + } while (access.next(iter)); + int length = access.getLength(iter); + return RComplex.valueOf(sumR / length, sumI / length); } - return div.op(sum, x.getLength()); } - @Specialization - protected RComplex mean(RAbstractComplexVector x) { - if (x.getLength() == 0) { - emptyProfile.enter(); - return RDataFactory.createComplex(Double.NaN, Double.NaN); - } - RComplex sum = x.getDataAt(0); - RComplex comp; - for (int k = 1; k < x.getLength(); k++) { - comp = x.getDataAt(k); - sum = add.op(sum.getRealPart(), sum.getImaginaryPart(), comp.getRealPart(), comp.getImaginaryPart()); - } - return div.op(sum.getRealPart(), sum.getImaginaryPart(), x.getLength(), 0); + @Specialization(replaces = "meanComplexCached", guards = "x.getRType() == Complex") + protected RComplex meanComplexGeneric(RAbstractVector x, + @Cached("createBinaryProfile()") ConditionProfile emptyProfile) { + return meanComplexCached(x, x.slowPathAccess(), emptyProfile); } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index a56d8071cfde4bb15226e60302e0803fd5d3f2c8..c015d5ee754805752a2d98c3a36440bb74ef9c7a 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -39127,7 +39127,7 @@ integer(0) #argv <- structure(list(x = structure(c(31, NA, NA, 31), units = 'days', class = 'difftime'), na.rm = TRUE), .Names = c('x', 'na.rm'));do.call('mean', argv) Time difference of 31 days -##com.oracle.truffle.r.test.builtins.TestBuiltin_mean.testmean2#Ignored.ImplementationError# +##com.oracle.truffle.r.test.builtins.TestBuiltin_mean.testmean2# #argv <- list(c(0.104166666666667, 0.285714285714286, 0.285714285714286, NA)); .Internal(mean(argv[[1]])) [1] NA diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_mean.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_mean.java index e0eb250d0b058e83e9325a6a904ddeb3e76b8272..450cfba4a7e0720977f648d9cb9ffe00ea3def90 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_mean.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_mean.java @@ -24,11 +24,7 @@ public class TestBuiltin_mean extends TestBase { @Test public void testmean2() { - // FIXME NA is returned by GnuR for NA input - // Expected output: [1] NA - // FastR output: [1] NaN - - assertEval(Ignored.ImplementationError, "argv <- list(c(0.104166666666667, 0.285714285714286, 0.285714285714286, NA)); .Internal(mean(argv[[1]]))"); + assertEval("argv <- list(c(0.104166666666667, 0.285714285714286, 0.285714285714286, NA)); .Internal(mean(argv[[1]]))"); } @Test