diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ColMeans.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ColMeans.java index 518c42f361639a3b9265921b49be991dd66606ae..6ae864b0cfdb7f43aa704f9d8ff97b20891af380 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ColMeans.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ColMeans.java @@ -10,16 +10,10 @@ */ package com.oracle.truffle.r.nodes.builtin.base; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; import com.oracle.truffle.api.dsl.Specialization; -import com.oracle.truffle.api.profiles.ConditionProfile; -import com.oracle.truffle.r.nodes.builtin.CastBuilder; -import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; -import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RDataFactory; @@ -27,125 +21,14 @@ import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; -import com.oracle.truffle.r.runtime.ops.na.NACheck; //Implements .colMeans @RBuiltin(name = "colMeans", kind = INTERNAL, parameterNames = {"X", "m", "n", "na.rm"}, behavior = PURE) -public abstract class ColMeans extends RBuiltinNode { +public abstract class ColMeans extends ColSumsBase { @Child private BinaryArithmetic add = BinaryArithmetic.ADD.create(); - private final NACheck na = NACheck.create(); - private final ConditionProfile vectorLengthProfile = ConditionProfile.createBinaryProfile(); - - @Override - protected void createCasts(CastBuilder casts) { - casts.arg("X").mustBe(numericValue(), RError.NO_CALLER, RError.Message.X_NUMERIC); - - casts.arg("m").defaultError(RError.NO_CALLER, RError.Message.INVALID_ARGUMENT, "n").asIntegerVector().findFirst().notNA(RError.NO_CALLER, RError.Message.VECTOR_SIZE_NA); - - casts.arg("n").defaultError(RError.NO_CALLER, RError.Message.INVALID_ARGUMENT, "p").asIntegerVector().findFirst().notNA(RError.NO_CALLER, RError.Message.VECTOR_SIZE_NA); - - casts.arg("na.rm").defaultError(RError.NO_CALLER, RError.Message.INVALID_ARGUMENT, "na.rm").asLogicalVector().findFirst().notNA().map(toBoolean()); - } - - private void checkVectorLength(RAbstractVector x, int rowNum, int colNum) { - if (vectorLengthProfile.profile(x.getLength() < rowNum * colNum)) { - throw RError.error(RError.NO_CALLER, RError.Message.TOO_SHORT, "X"); - } - } - - protected boolean isEmptyMatrix(int rowNum, int colNum) { - return rowNum == 0 && colNum == 0; - } - - @Specialization(guards = "isEmptyMatrix(rowNum, colNum)") - @SuppressWarnings("unused") - protected RDoubleVector colMeansEmptyMatrix(Object x, int rowNum, int colNum, boolean naRm) { - return RDataFactory.createEmptyDoubleVector(); - } - - @Specialization(guards = "!naRm") - protected RDoubleVector colMeansScalarNaRmFalse(double x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - if (vectorLengthProfile.profile(rowNum * colNum > 1)) { - throw RError.error(RError.NO_CALLER, RError.Message.TOO_SHORT, "X"); - } - - return RDataFactory.createDoubleVectorFromScalar(x); - } - - @Specialization(guards = "naRm") - protected RDoubleVector colMeansScalarNaRmTrue(double x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - if (vectorLengthProfile.profile(rowNum * colNum > 1)) { - throw RError.error(RError.NO_CALLER, RError.Message.TOO_SHORT, "X"); - } - - na.enable(x); - if (!na.check(x) && !Double.isNaN(x)) { - return RDataFactory.createDoubleVectorFromScalar(x); - } else { - return RDataFactory.createDoubleVectorFromScalar(Double.NaN); - } - } - - @Specialization(guards = "!naRm") - protected RDoubleVector colMeansScalarNaRmFalse(int x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - if (vectorLengthProfile.profile(rowNum * colNum > 1)) { - throw RError.error(RError.NO_CALLER, RError.Message.TOO_SHORT, "X"); - } - - na.enable(x); - if (!na.check(x)) { - return RDataFactory.createDoubleVectorFromScalar(x); - } else { - return RDataFactory.createDoubleVectorFromScalar(RRuntime.DOUBLE_NA); - } - } - - @Specialization(guards = "naRm") - protected RDoubleVector colMeansScalarNaRmTrue(int x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - if (vectorLengthProfile.profile(rowNum * colNum > 1)) { - throw RError.error(RError.NO_CALLER, RError.Message.TOO_SHORT, "X"); - } - - na.enable(x); - if (!na.check(x)) { - return RDataFactory.createDoubleVectorFromScalar(x); - } else { - return RDataFactory.createDoubleVectorFromScalar(Double.NaN); - } - } - - @Specialization(guards = "!naRm") - protected RDoubleVector colMeansScalarNaRmFalse(byte x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - if (vectorLengthProfile.profile(rowNum * colNum > 1)) { - throw RError.error(RError.NO_CALLER, RError.Message.TOO_SHORT, "X"); - } - - na.enable(x); - if (!na.check(x)) { - return RDataFactory.createDoubleVectorFromScalar(x); - } else { - return RDataFactory.createDoubleVectorFromScalar(RRuntime.DOUBLE_NA); - } - } - - @Specialization(guards = "naRm") - protected RDoubleVector colMeansScalarNaRmTrue(byte x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - if (vectorLengthProfile.profile(rowNum * colNum > 1)) { - throw RError.error(RError.NO_CALLER, RError.Message.TOO_SHORT, "X"); - } - - na.enable(x); - if (!na.check(x)) { - return RDataFactory.createDoubleVectorFromScalar(x); - } else { - return RDataFactory.createDoubleVectorFromScalar(Double.NaN); - } - } - @Specialization(guards = "!naRm") protected RDoubleVector colMeansNaRmFalse(RAbstractDoubleVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { checkVectorLength(x, rowNum, colNum); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ColSums.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ColSums.java index 23564e03eae97ff55fa53bd7554005217f87738a..03ffd7f62e25c5a16cefdbdab66521004ff3e158 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ColSums.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ColSums.java @@ -22,17 +22,12 @@ */ package com.oracle.truffle.r.nodes.builtin.base; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.api.profiles.ValueProfile; -import com.oracle.truffle.r.nodes.builtin.CastBuilder; -import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; -import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RDataFactory; @@ -40,125 +35,15 @@ import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; -import com.oracle.truffle.r.runtime.ops.na.NACheck; @RBuiltin(name = "colSums", kind = INTERNAL, parameterNames = {"X", "m", "n", "na.rm"}, behavior = PURE) -public abstract class ColSums extends RBuiltinNode { +public abstract class ColSums extends ColSumsBase { @Child private BinaryArithmetic add = BinaryArithmetic.ADD.create(); - private final NACheck na = NACheck.create(); private final ConditionProfile removeNA = ConditionProfile.createBinaryProfile(); private final ValueProfile concreteVectorProfile = ValueProfile.createClassProfile(); - private final ConditionProfile vectorLengthProfile = ConditionProfile.createBinaryProfile(); - - @Override - protected void createCasts(CastBuilder casts) { - casts.arg("X").mustBe(numericValue(), RError.NO_CALLER, RError.Message.X_NUMERIC); - - casts.arg("m").defaultError(RError.NO_CALLER, RError.Message.INVALID_ARGUMENT, "n").asIntegerVector().findFirst().notNA(RError.NO_CALLER, RError.Message.VECTOR_SIZE_NA); - - casts.arg("n").defaultError(RError.NO_CALLER, RError.Message.INVALID_ARGUMENT, "p").asIntegerVector().findFirst().notNA(RError.NO_CALLER, RError.Message.VECTOR_SIZE_NA); - - casts.arg("na.rm").defaultError(RError.NO_CALLER, RError.Message.INVALID_ARGUMENT, "na.rm").asLogicalVector().findFirst().notNA().map(toBoolean()); - } - - private void checkVectorLength(RAbstractVector x, int rowNum, int colNum) { - if (vectorLengthProfile.profile(x.getLength() < rowNum * colNum)) { - throw RError.error(RError.NO_CALLER, RError.Message.TOO_SHORT, "X"); - } - } - - protected boolean isEmptyMatrix(int rowNum, int colNum) { - return rowNum == 0 && colNum == 0; - } - - @Specialization(guards = "isEmptyMatrix(rowNum, colNum)") - @SuppressWarnings("unused") - protected RDoubleVector colSumsEmptyMatrix(Object x, int rowNum, int colNum, boolean naRm) { - return RDataFactory.createEmptyDoubleVector(); - } - - @Specialization(guards = "!naRm") - protected RDoubleVector colSumsScalarNaRmFalse(double x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - if (vectorLengthProfile.profile(rowNum * colNum > 1)) { - throw RError.error(RError.NO_CALLER, RError.Message.TOO_SHORT, "X"); - } - - return RDataFactory.createDoubleVectorFromScalar(x); - } - - @Specialization(guards = "naRm") - protected RDoubleVector colSumsScalarNaRmTrue(double x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - if (vectorLengthProfile.profile(rowNum * colNum > 1)) { - throw RError.error(RError.NO_CALLER, RError.Message.TOO_SHORT, "X"); - } - - na.enable(x); - if (!na.check(x) && !Double.isNaN(x)) { - return RDataFactory.createDoubleVectorFromScalar(x); - } else { - return RDataFactory.createDoubleVectorFromScalar(Double.NaN); - } - } - - @Specialization(guards = "!naRm") - protected RDoubleVector colSumsScalarNaRmFalse(int x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - if (vectorLengthProfile.profile(rowNum * colNum > 1)) { - throw RError.error(RError.NO_CALLER, RError.Message.TOO_SHORT, "X"); - } - - na.enable(x); - if (!na.check(x)) { - return RDataFactory.createDoubleVectorFromScalar(x); - } else { - return RDataFactory.createDoubleVectorFromScalar(RRuntime.DOUBLE_NA); - } - } - - @Specialization(guards = "naRm") - protected RDoubleVector colSumsScalarNaRmTrue(int x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - if (vectorLengthProfile.profile(rowNum * colNum > 1)) { - throw RError.error(RError.NO_CALLER, RError.Message.TOO_SHORT, "X"); - } - - na.enable(x); - if (!na.check(x)) { - return RDataFactory.createDoubleVectorFromScalar(x); - } else { - return RDataFactory.createDoubleVectorFromScalar(Double.NaN); - } - } - - @Specialization(guards = "!naRm") - protected RDoubleVector colSumsScalarNaRmFalse(byte x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - if (vectorLengthProfile.profile(rowNum * colNum > 1)) { - throw RError.error(RError.NO_CALLER, RError.Message.TOO_SHORT, "X"); - } - - na.enable(x); - if (!na.check(x)) { - return RDataFactory.createDoubleVectorFromScalar(x); - } else { - return RDataFactory.createDoubleVectorFromScalar(RRuntime.DOUBLE_NA); - } - } - - @Specialization(guards = "naRm") - protected RDoubleVector colSumsScalarNaRmTrue(byte x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - if (vectorLengthProfile.profile(rowNum * colNum > 1)) { - throw RError.error(RError.NO_CALLER, RError.Message.TOO_SHORT, "X"); - } - - na.enable(x); - if (!na.check(x)) { - return RDataFactory.createDoubleVectorFromScalar(x); - } else { - return RDataFactory.createDoubleVectorFromScalar(Double.NaN); - } - } @Specialization protected RDoubleVector colSums(RAbstractDoubleVector x, int rowNum, int colNum, boolean rnaParam) { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ColSumsBase.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ColSumsBase.java new file mode 100644 index 0000000000000000000000000000000000000000..0b8d48c4bad0821a40bab29a1a4a308016d810bf --- /dev/null +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/ColSumsBase.java @@ -0,0 +1,136 @@ +/* + * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.builtin.base; + +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean; +import static com.oracle.truffle.r.runtime.RError.Message.INVALID_ARGUMENT; + +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.r.nodes.builtin.CastBuilder; +import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.RRuntime; +import com.oracle.truffle.r.runtime.data.RDataFactory; +import com.oracle.truffle.r.runtime.data.RDoubleVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.ops.na.NACheck; + +/** + * Base class that provides arguments handling and validation helper methods and trivial cases + * specializations shared between {@link RowSums}, {@link RowMeans}, {@link ColMeans}, + * {@link RowSums}. + */ +public abstract class ColSumsBase extends RBuiltinNode { + + protected final NACheck na = NACheck.create(); + private final ConditionProfile vectorLengthProfile = ConditionProfile.createBinaryProfile(); + + @Override + protected final void createCasts(CastBuilder casts) { + casts.arg("X").mustBe(numericValue(), RError.SHOW_CALLER, RError.Message.X_NUMERIC); + casts.arg("m").defaultError(RError.SHOW_CALLER, INVALID_ARGUMENT, "n").asIntegerVector().findFirst().notNA(RError.NO_CALLER, RError.Message.VECTOR_SIZE_NA); + casts.arg("n").defaultError(RError.SHOW_CALLER, INVALID_ARGUMENT, "p").asIntegerVector().findFirst().notNA(RError.NO_CALLER, RError.Message.VECTOR_SIZE_NA); + casts.arg("na.rm").asLogicalVector().findFirst().notNA().map(toBoolean()); + } + + protected final void checkVectorLength(RAbstractVector x, int rowNum, int colNum) { + if (vectorLengthProfile.profile(x.getLength() < rowNum * colNum)) { + throw RError.error(RError.NO_CALLER, RError.Message.TOO_SHORT, "X"); + } + } + + @Specialization(guards = {"rowNum == 0", "colNum == 0"}) + @SuppressWarnings("unused") + protected final RDoubleVector doEmptyMatrix(Object x, int rowNum, int colNum, boolean naRm) { + return RDataFactory.createEmptyDoubleVector(); + } + + @Specialization(guards = "!naRm") + protected final RDoubleVector doScalarNaRmFalse(double x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { + checkLengthOne(rowNum, colNum); + return RDataFactory.createDoubleVectorFromScalar(x); + } + + @Specialization(guards = "naRm") + protected final RDoubleVector doScalarNaRmTrue(double x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { + checkLengthOne(rowNum, colNum); + na.enable(x); + if (!na.check(x) && !Double.isNaN(x)) { + return RDataFactory.createDoubleVectorFromScalar(x); + } else { + return RDataFactory.createDoubleVectorFromScalar(Double.NaN); + } + } + + @Specialization(guards = "!naRm") + protected final RDoubleVector doScalarNaRmFalse(int x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { + checkLengthOne(rowNum, colNum); + na.enable(x); + if (!na.check(x)) { + return RDataFactory.createDoubleVectorFromScalar(x); + } else { + return RDataFactory.createDoubleVectorFromScalar(RRuntime.DOUBLE_NA); + } + } + + @Specialization(guards = "naRm") + protected final RDoubleVector doScalarNaRmTrue(int x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { + checkLengthOne(rowNum, colNum); + na.enable(x); + if (!na.check(x)) { + return RDataFactory.createDoubleVectorFromScalar(x); + } else { + return RDataFactory.createDoubleVectorFromScalar(Double.NaN); + } + } + + @Specialization(guards = "!naRm") + protected final RDoubleVector doScalarNaRmFalse(byte x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { + checkLengthOne(rowNum, colNum); + na.enable(x); + if (!na.check(x)) { + return RDataFactory.createDoubleVectorFromScalar(x); + } else { + return RDataFactory.createDoubleVectorFromScalar(RRuntime.DOUBLE_NA); + } + } + + @Specialization(guards = "naRm") + protected final RDoubleVector doScalarNaRmTrue(byte x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { + checkLengthOne(rowNum, colNum); + na.enable(x); + if (!na.check(x)) { + return RDataFactory.createDoubleVectorFromScalar(x); + } else { + return RDataFactory.createDoubleVectorFromScalar(Double.NaN); + } + } + + private void checkLengthOne(int rowNum, int colNum) { + if (vectorLengthProfile.profile(rowNum * colNum > 1)) { + throw RError.error(RError.NO_CALLER, RError.Message.TOO_SHORT, "X"); + } + } +} diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowMeans.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowMeans.java index ca940f61df65a057af3aad080c4776d761f8a624..dd72109980a9ce6bc16e31f436ce7f28737f10da 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowMeans.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowMeans.java @@ -10,185 +10,41 @@ */ package com.oracle.truffle.r.nodes.builtin.base; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean; -import static com.oracle.truffle.r.runtime.RError.SHOW_CALLER; -import static com.oracle.truffle.r.runtime.RError.Message.X_NUMERIC; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; -import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Specialization; -import com.oracle.truffle.r.nodes.builtin.CastBuilder; -import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; -import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RBuiltin; -import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDoubleVector; -import com.oracle.truffle.r.runtime.data.RIntVector; -import com.oracle.truffle.r.runtime.data.RLogicalVector; -import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; -import com.oracle.truffle.r.runtime.ops.na.NACheck; +import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; // Implements .rowMeans @RBuiltin(name = "rowMeans", kind = INTERNAL, parameterNames = {"X", "m", "n", "na.rm"}, behavior = PURE) -public abstract class RowMeans extends RBuiltinNode { - - @Child private BinaryArithmetic add = BinaryArithmetic.ADD.create(); - private final NACheck na = NACheck.create(); - - @Override - protected void createCasts(CastBuilder casts) { - casts.arg("X").mustBe(numericValue(), SHOW_CALLER, X_NUMERIC); - - casts.arg("m").asIntegerVector().findFirst().notNA(); - - casts.arg("n").asIntegerVector().findFirst().notNA(); - - casts.arg("na.rm").asLogicalVector().findFirst().map(toBoolean()); - } - - @Specialization(guards = "!naRm") - @TruffleBoundary - protected RDoubleVector rowMeansNaRmFalse(RDoubleVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - double[] result = new double[rowNum]; - boolean isComplete = true; - na.enable(x); - nextRow: for (int i = 0; i < rowNum; i++) { - double sum = 0; - for (int c = 0; c < colNum; c++) { - double el = x.getDataAt(c * rowNum + i); - if (na.check(el)) { - result[i] = RRuntime.DOUBLE_NA; - continue nextRow; - } - if (Double.isNaN(el)) { - result[i] = Double.NaN; - isComplete = false; - continue nextRow; - } - sum = add.op(sum, el); - } - result[i] = sum / colNum; - } - return RDataFactory.createDoubleVector(result, na.neverSeenNA() && isComplete); +public abstract class RowMeans extends RowSumsBase { + @Specialization + protected RDoubleVector rowMeans(RAbstractDoubleVector x, int rowNum, int colNum, boolean naRm) { + return accumulateRows(x, rowNum, colNum, naRm, RowMeans::getMean, (v, nacheck, i) -> v.getDataAt(i)); } - @Specialization(guards = "naRm") - @TruffleBoundary - protected RDoubleVector rowMeansNaRmTrue(RDoubleVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - double[] result = new double[rowNum]; - boolean isComplete = true; - na.enable(x); - for (int i = 0; i < rowNum; i++) { - double sum = 0; - int nonNaNumCount = 0; - for (int c = 0; c < colNum; c++) { - double el = x.getDataAt(c * rowNum + i); - if (!na.check(el) && !Double.isNaN(el)) { - sum = add.op(sum, el); - nonNaNumCount++; - } - } - if (nonNaNumCount == 0) { - result[i] = Double.NaN; - isComplete = false; - } else { - result[i] = sum / nonNaNumCount; - } - } - return RDataFactory.createDoubleVector(result, isComplete); - } - - @Specialization(guards = "!naRm") - @TruffleBoundary - protected RDoubleVector rowMeansNaRmFalse(RLogicalVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - double[] result = new double[rowNum]; - na.enable(x); - nextRow: for (int i = 0; i < rowNum; i++) { - double sum = 0; - for (int c = 0; c < colNum; c++) { - byte el = x.getDataAt(c * rowNum + i); - if (na.check(el)) { - result[i] = RRuntime.DOUBLE_NA; - continue nextRow; - } - sum = add.op(sum, el); - } - result[i] = sum / colNum; - } - return RDataFactory.createDoubleVector(result, na.neverSeenNA()); + @Specialization + protected RDoubleVector rowMeans(RAbstractIntVector x, int rowNum, int colNum, boolean naRm) { + return accumulateRows(x, rowNum, colNum, naRm, RowMeans::getMean, (v, nacheck, i) -> nacheck.convertIntToDouble(v.getDataAt(i))); } - @Specialization(guards = "naRm") - @TruffleBoundary - protected RDoubleVector rowMeansNaRmTrue(RLogicalVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - double[] result = new double[rowNum]; - boolean isComplete = true; - na.enable(x); - for (int i = 0; i < rowNum; i++) { - double sum = 0; - int nonNaNumCount = 0; - for (int c = 0; c < colNum; c++) { - byte el = x.getDataAt(c * rowNum + i); - if (!na.check(el)) { - sum = add.op(sum, el); - nonNaNumCount++; - } - } - if (nonNaNumCount == 0) { - result[i] = Double.NaN; - isComplete = false; - } else { - result[i] = sum / nonNaNumCount; - } - } - return RDataFactory.createDoubleVector(result, isComplete); - } - - @Specialization(guards = "!naRm") - @TruffleBoundary - protected RDoubleVector rowMeansNaRmFalse(RIntVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - double[] result = new double[rowNum]; - na.enable(x); - nextRow: for (int i = 0; i < rowNum; i++) { - double sum = 0; - for (int c = 0; c < colNum; c++) { - int el = x.getDataAt(c * rowNum + i); - if (na.check(el)) { - result[i] = RRuntime.DOUBLE_NA; - continue nextRow; - } - sum = add.op(sum, el); - } - result[i] = sum / colNum; - } - return RDataFactory.createDoubleVector(result, na.neverSeenNA()); + @Specialization + protected RDoubleVector rowMeans(RAbstractLogicalVector x, int rowNum, int colNum, boolean naRm) { + return accumulateRows(x, rowNum, colNum, naRm, RowMeans::getMean, (v, nacheck, i) -> nacheck.convertLogicalToDouble(v.getDataAt(i))); } - @Specialization(guards = "naRm") - @TruffleBoundary - protected RDoubleVector rowMeansNaRmTrue(RIntVector x, int rowNum, int colNum, @SuppressWarnings("unused") boolean naRm) { - double[] result = new double[rowNum]; - boolean isComplete = true; - na.enable(x); - for (int i = 0; i < rowNum; i++) { - double sum = 0; - int nonNaNumCount = 0; - for (int c = 0; c < colNum; c++) { - int el = x.getDataAt(c * rowNum + i); - if (!na.check(el)) { - sum = add.op(sum, el); - nonNaNumCount++; - } - } - if (nonNaNumCount == 0) { - result[i] = Double.NaN; - isComplete = false; - } else { - result[i] = sum / nonNaNumCount; - } + private static double getMean(double sum, int notNACount) { + if (Double.isNaN(sum)) { + return sum; + } else if (notNACount == 0) { + return Double.NaN; + } else { + return sum / notNACount; } - return RDataFactory.createDoubleVector(result, isComplete); } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowSums.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowSums.java index 3660bf75ee1d03e56c6beacf49d56d6b6da2f8a5..4fe2545e4c6436b4c6f9f347c7f78c672880f687 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowSums.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowSums.java @@ -13,122 +13,27 @@ package com.oracle.truffle.r.nodes.builtin.base; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; -import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Specialization; -import com.oracle.truffle.api.profiles.ConditionProfile; -import com.oracle.truffle.api.profiles.LoopConditionProfile; -import com.oracle.truffle.r.nodes.builtin.CastBuilder; -import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; -import com.oracle.truffle.r.runtime.RError; -import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RBuiltin; -import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractVector; -import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; -import com.oracle.truffle.r.runtime.ops.na.NACheck; @RBuiltin(name = "rowSums", kind = INTERNAL, parameterNames = {"X", "m", "n", "na.rm"}, behavior = PURE) -public abstract class RowSums extends RBuiltinNode { - - /* - * this builtin unrolls the innermost loop (calculating multiple sums at once) to optimize cache - * behavior. - */ - private static final int UNROLL = 8; - - @Child private BinaryArithmetic add = BinaryArithmetic.ADD.create(); - - private final NACheck na = NACheck.create(); - - private final ConditionProfile removeNA = ConditionProfile.createBinaryProfile(); - private final ConditionProfile remainderProfile = ConditionProfile.createBinaryProfile(); - private final LoopConditionProfile outerProfile = LoopConditionProfile.createCountingProfile(); - private final LoopConditionProfile innerProfile = LoopConditionProfile.createCountingProfile(); - - @Override - protected void createCasts(CastBuilder casts) { - casts.toInteger(1).toInteger(2); - } - - @FunctionalInterface - private interface GetFunction<T extends RAbstractVector> { - double get(T vector, NACheck na, int index); - } - - private <T extends RAbstractVector> RDoubleVector performSums(T x, int rowNum, int colNum, byte naRm, GetFunction<T> get) { - reportWork(x.getLength()); - double[] result = new double[rowNum]; - final boolean rna = removeNA.profile(naRm == RRuntime.LOGICAL_TRUE); - na.enable(x); - outerProfile.profileCounted(rowNum / 4); - innerProfile.profileCounted(colNum); - int i = 0; - // the unrolled loop cannot handle NA values - if (!na.isEnabled()) { - while (outerProfile.inject(i <= rowNum - UNROLL)) { - double[] sum = new double[UNROLL]; - int pos = i; - for (int c = 0; innerProfile.inject(c < colNum); c++) { - for (int unroll = 0; unroll < UNROLL; unroll++) { - sum[unroll] = add.op(sum[unroll], get.get(x, na, pos + unroll)); - } - pos += rowNum; - } - for (int unroll = 0; unroll < UNROLL; unroll++) { - result[i + unroll] = sum[unroll]; - } - i += UNROLL; - } - } - if (remainderProfile.profile(i < rowNum)) { - while (i < rowNum) { - double sum = 0; - int pos = i; - for (int c = 0; innerProfile.inject(c < colNum); c++) { - double el = get.get(x, na, pos); - pos += rowNum; - if (Double.isNaN(el)) { - // call check to make sure neverSeenNA is correct - na.check(el); - if (!rna) { - sum = el; - break; - } - } else { - sum = add.op(sum, el); - } - } - result[i] = sum; - i++; - } - } - return RDataFactory.createDoubleVector(result, na.neverSeenNA()); - } - - @Specialization - protected RDoubleVector rowSums(RAbstractDoubleVector x, int rowNum, int colNum, byte naRm) { - return performSums(x, rowNum, colNum, naRm, (v, nacheck, i) -> v.getDataAt(i)); - } - +public abstract class RowSums extends RowSumsBase { @Specialization - protected RDoubleVector rowSums(RAbstractIntVector x, int rowNum, int colNum, byte naRm) { - return performSums(x, rowNum, colNum, naRm, (v, nacheck, i) -> nacheck.convertIntToDouble(v.getDataAt(i))); + protected RDoubleVector rowSums(RAbstractDoubleVector x, int rowNum, int colNum, boolean naRm) { + return accumulateRows(x, rowNum, colNum, naRm, (sum, cnt) -> sum, (v, nacheck, i) -> v.getDataAt(i)); } @Specialization - protected RDoubleVector rowSums(RAbstractLogicalVector x, int rowNum, int colNum, byte naRm) { - return performSums(x, rowNum, colNum, naRm, (v, nacheck, i) -> nacheck.convertLogicalToDouble(v.getDataAt(i))); + protected RDoubleVector rowSums(RAbstractIntVector x, int rowNum, int colNum, boolean naRm) { + return accumulateRows(x, rowNum, colNum, naRm, (sum, cnt) -> sum, (v, nacheck, i) -> nacheck.convertIntToDouble(v.getDataAt(i))); } - @SuppressWarnings("unused") @Specialization - @TruffleBoundary - protected RDoubleVector rowSums(RAbstractStringVector x, int rowNum, int colNum, byte naRm) { - throw RError.error(this, RError.Message.X_NUMERIC); + protected RDoubleVector rowSums(RAbstractLogicalVector x, int rowNum, int colNum, boolean naRm) { + return accumulateRows(x, rowNum, colNum, naRm, (sum, cnt) -> sum, (v, nacheck, i) -> nacheck.convertLogicalToDouble(v.getDataAt(i))); } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowSumsBase.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowSumsBase.java new file mode 100644 index 0000000000000000000000000000000000000000..43bfe087d83c632f61c42ac13b0f4caa009657e7 --- /dev/null +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowSumsBase.java @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.builtin.base; + +import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.api.profiles.LoopConditionProfile; +import com.oracle.truffle.r.runtime.RRuntime; +import com.oracle.truffle.r.runtime.data.RDataFactory; +import com.oracle.truffle.r.runtime.data.RDoubleVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; +import com.oracle.truffle.r.runtime.ops.na.NACheck; + +/** + * Implements a skeleton of an algorithm that traverses rows and accumulates their values. + */ +public abstract class RowSumsBase extends ColSumsBase { + + /* + * this builtin unrolls the innermost loop (calculating multiple sums at once) to optimize cache + * behavior. + */ + private static final int UNROLL = 8; + + @Child private BinaryArithmetic add = BinaryArithmetic.ADD.create(); + + private final ConditionProfile removeNA = ConditionProfile.createBinaryProfile(); + private final ConditionProfile remainderProfile = ConditionProfile.createBinaryProfile(); + private final LoopConditionProfile outerProfile = LoopConditionProfile.createCountingProfile(); + private final LoopConditionProfile innerProfile = LoopConditionProfile.createCountingProfile(); + + @FunctionalInterface + protected interface GetFunction<T extends RAbstractVector> { + double get(T vector, NACheck na, int index); + } + + @FunctionalInterface + protected interface FinalTransform { + double get(double sum, int notNACount); + } + + protected final <T extends RAbstractVector> RDoubleVector accumulateRows(T x, int rowNum, int colNum, boolean naRm, FinalTransform finalTransform, RowSumsBase.GetFunction<T> get) { + reportWork(x.getLength()); + double[] result = new double[rowNum]; + na.enable(x); + outerProfile.profileCounted(rowNum / 4); + innerProfile.profileCounted(colNum); + int i = 0; + // the unrolled loop cannot handle NA values + if (!na.isEnabled()) { + while (outerProfile.inject(i <= rowNum - UNROLL)) { + double[] sum = new double[UNROLL]; + int pos = i; + for (int c = 0; innerProfile.inject(c < colNum); c++) { + for (int unroll = 0; unroll < UNROLL; unroll++) { + sum[unroll] = add.op(sum[unroll], get.get(x, na, pos + unroll)); + } + pos += rowNum; + } + for (int unroll = 0; unroll < UNROLL; unroll++) { + result[i + unroll] = finalTransform.get(sum[unroll], colNum); + } + i += UNROLL; + } + } + if (remainderProfile.profile(i < rowNum)) { + while (i < rowNum) { + double sum = 0; + int pos = i; + int notNACount = 0; + for (int c = 0; innerProfile.inject(c < colNum); c++) { + double el = get.get(x, na, pos); + pos += rowNum; + if (na.check(el)) { + if (!naRm) { + sum = RRuntime.DOUBLE_NA; + break; + } + } else if (Double.isNaN(el)) { + if (!naRm) { + sum = Double.NaN; + break; + } + } else { + sum = add.op(sum, el); + notNACount++; + } + } + result[i] = finalTransform.get(sum, notNACount); + i++; + } + } + return RDataFactory.createDoubleVector(result, na.neverSeenNA()); + } +} diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowsumFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowsumFunctions.java index b0e65daa31828bff7c0c7e0a1889f3f6aeaf7feb..782ee89bcaaac6cf996fbfcc559bc6ee93df2aa2 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowsumFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowsumFunctions.java @@ -30,11 +30,12 @@ import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RDataFactory; -import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.RIntVector; import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.ops.na.NACheck; @@ -88,7 +89,7 @@ public class RowsumFunctions { boolean complete = xv.isComplete(); if (typeProfile.profile(isInt)) { - RIntVector xi = (RIntVector) xv; + RAbstractIntVector xi = (RAbstractIntVector) xv; int[] ansi = new int[ng * p]; for (int i = 0; i < p; i++) { for (int j = 0; j < n; j++) { @@ -117,7 +118,7 @@ public class RowsumFunctions { } result = RDataFactory.createIntVector(ansi, complete, new int[]{ng, p}); } else { - RDoubleVector xd = (RDoubleVector) xv; + RAbstractDoubleVector xd = (RAbstractDoubleVector) xv; double[] ansd = new double[ng * p]; for (int i = 0; i < p; i++) { for (int j = 0; j < n; j++) { diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/na/NACheck.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/na/NACheck.java index e7cfa312c1d4f277d0e77f4b3bcf2345324f192a..0650ab020183526a1a439aaa01fec703128bfb54 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/na/NACheck.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/ops/na/NACheck.java @@ -195,6 +195,8 @@ public final class NACheck { if (checkNAorNaN(value)) { // Special case here NaN does not enable the NA check. this.enable(true); + // Note: GnuR seems to convert NaN to NaN + 0i and NA to NA, but doing it here breaks + // other things return RRuntime.createComplexNA(); } return RDataFactory.createComplex(value, 0); diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index e185832122485adabd8797a831ef1d48c34afe26..75045604ebfb8c3c0b22469ba7ec5b66a8b193fe 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -671,6 +671,14 @@ Error in `Encoding<-`(`*tmp*`, value = "UTF-8") : #argv <- structure(list(year = 2002, month = 6, day = 24, hour = 0, min = 0, sec = 10), .Names = c('year', 'month', 'day', 'hour', 'min', 'sec'));do.call('ISOdatetime', argv) [1] "2002-06-24 00:00:10 CEST" +##com.oracle.truffle.r.test.builtins.TestBuiltin_Im.testIm +#Im(NaN) +[1] 0 + +##com.oracle.truffle.r.test.builtins.TestBuiltin_Im.testIm +#Im(c(NaN, 1+1i)) +[1] 0 1 + ##com.oracle.truffle.r.test.builtins.TestBuiltin_Im.testIm #{ Im(1) } [1] 0 @@ -1214,6 +1222,14 @@ NAs introduced by coercion #.Internal(RNGkind(NULL, NULL)) [1] 3 4 +##com.oracle.truffle.r.test.builtins.TestBuiltin_Re.testRe +#Re(NaN) +[1] NaN + +##com.oracle.truffle.r.test.builtins.TestBuiltin_Re.testRe +#Re(c(NaN, 1+1i)) +[1] NaN 1 + ##com.oracle.truffle.r.test.builtins.TestBuiltin_Re.testRe #{ Re(1) } [1] 1 @@ -41855,6 +41871,10 @@ Error in matrix(NA, NA, NA) : invalid 'nrow' value (too large or NA) #{rowMeans(matrix(c(NA,NaN,NaN,NA),ncol=2,nrow=2))} [1] NA NaN +##com.oracle.truffle.r.test.builtins.TestBuiltin_rowMeans.testRowMeans +#{rowMeans(matrix(c(NaN,4+5i,2+0i,5+10i),nrow=2,ncol=2), na.rm = FALSE)} +[1] NA 4.5+7.5i + ##com.oracle.truffle.r.test.builtins.TestBuiltin_rowMeans.testRowMeans #{rowMeans(matrix(c(NaN,4+5i,2+0i,5+10i),nrow=2,ncol=2), na.rm = TRUE)} [1] 2.0+0.0i 4.5+7.5i @@ -48675,10 +48695,6 @@ In matrix(7:1, nrow = 5) : [4,] 11 15 19 -##com.oracle.truffle.r.test.builtins.TestBuiltin_sweep.testSweep -#{rowMeans(matrix(c(NaN,4+5i,2+0i,5+10i),nrow=2,ncol=2), na.rm = FALSE)} -[1] NA 4.5+7.5i - ##com.oracle.truffle.r.test.builtins.TestBuiltin_sweep.testsweep1 #argv <- structure(list(x = structure(integer(0), .Dim = c(5L, 0L)), MARGIN = 2, STATS = integer(0)), .Names = c('x', 'MARGIN', 'STATS'));do.call('sweep', argv) @@ -59742,6 +59758,10 @@ logical(0) #{ x<-c(1,2,3);x+TRUE } [1] 2 3 4 +##com.oracle.truffle.r.test.library.base.TestSimpleArithmetic.testVectorsComplex +#x <- c(NaN, 3+2i); xre <- Re(x); xim <- (0+1i) * Im(x); xre + xim +[1] NA 3+2i + ##com.oracle.truffle.r.test.library.base.TestSimpleArithmetic.testVectorsComplex #{ 1:4+c(1,2+2i) } [1] 2+0i 4+2i 4+0i 6+2i diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_Im.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_Im.java index 82b42bc317582c0e660d61da379f5886f9e56dd3..76738ecfa18a5b432ee11fb4be04bff5fca7952d 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_Im.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_Im.java @@ -62,5 +62,8 @@ public class TestBuiltin_Im extends TestBase { assertEval("{ x <- c(1+2i,3-4i) ; attr(x,\"my\") <- 2 ; Im(x) }"); assertEval("{ Im(as.raw(12)) }"); + + assertEval(Ignored.ImplementationError, "Im(c(NaN, 1+1i))"); + assertEval("Im(NaN)"); } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_Re.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_Re.java index 26ce368ea0732d3e46063d439f89f6c8cbef3b6a..7eac103d968c9f9e2baac2e6176c20aa684b4c8d 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_Re.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_Re.java @@ -66,5 +66,8 @@ public class TestBuiltin_Re extends TestBase { assertEval("{ x <- c(1+2i,3-4i) ; attr(x,\"my\") <- 2 ; Re(x) }"); assertEval("{ Re(as.raw(12)) }"); + + assertEval(Ignored.ImplementationError, "Re(c(NaN, 1+1i))"); + assertEval("Re(NaN)"); } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_colMeans.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_colMeans.java index f1fabfff8609d544e56df2f0920fa171452abedd..b35955c9cf4696abfab70ff619aae08fcb98bf37 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_colMeans.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_colMeans.java @@ -105,6 +105,9 @@ public class TestBuiltin_colMeans extends TestBase { assertEval("{colMeans(matrix(c(NA,NaN,NaN,NA),ncol=2,nrow=2))}"); assertEval("{ a = colSums(array(1:24,c(2,3,4))); colMeans(a)}"); - assertEval(Ignored.Unknown, "{colMeans(matrix(c(NaN,4+5i,2+0i,5+10i),nrow=2,ncol=2), na.rm = TRUE)}"); + // Following fails not because of colMeans implementation, but because following code does + // not work: + // x <- c(NaN, 3+2i); xre <- Re(x); xim <- (0+1i) * Im(x); xre + xim + assertEval(Ignored.ImplementationError, "{colMeans(matrix(c(NaN,4+5i,2+0i,5+10i),nrow=2,ncol=2), na.rm = TRUE)}"); } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_colSums.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_colSums.java index ebe0d3c1c9b8f24284961fa03de03264e3362ec9..9bd8089dc09d5ce4ebd46847a95652bdd6fa4f1a 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_colSums.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_colSums.java @@ -12,10 +12,6 @@ package com.oracle.truffle.r.test.builtins; import org.junit.Test; -import com.oracle.truffle.api.TruffleLanguage; -import com.oracle.truffle.api.frame.VirtualFrame; -import com.oracle.truffle.api.nodes.RootNode; -import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.test.TestBase; // Checkstyle: stop line length check @@ -95,19 +91,4 @@ public class TestBuiltin_colSums extends TestBase { // colSums on array have correct values assertEval("{ a = colSums(array(1:24,c(2,3,4))); c(a[1,1],a[2,2],a[3,3],a[3,4]) }"); } - - class RBuiltinRootNode extends RootNode { - - @Child private RBuiltinNode builtinNode; - - RBuiltinRootNode(RBuiltinNode builtinNode) { - super(TruffleLanguage.class, null, null); - this.builtinNode = builtinNode; - } - - @Override - public Object execute(VirtualFrame frame) { - return builtinNode.execute(frame); - } - } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rowMeans.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rowMeans.java index 8d61e4214075bd634e359c1047addd67c48d47a8..7f85178c0b2a08ebef095085aa7f48bdd0183d04 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rowMeans.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rowMeans.java @@ -59,7 +59,8 @@ public class TestBuiltin_rowMeans extends TestBase { assertEval("{rowMeans(matrix(c(NA,NaN,NaN,NA),ncol=2,nrow=2))}"); assertEval("{x<-matrix(c(\"1\",\"2\",\"3\",\"4\"),ncol=2);rowMeans(x)}"); - // Error message mismatch + assertEval("{rowMeans(matrix(c(NaN,4+5i,2+0i,5+10i),nrow=2,ncol=2), na.rm = FALSE)}"); + // Internal error in matrix(NA, NA, NA) assertEval(Ignored.Unknown, "{rowMeans(matrix(NA,NA,NA),TRUE)}"); } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rowSums.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rowSums.java index a84133cafbc91f3cb1142ae9109ee5ce11322f7c..728dda3c30e2c663b7b4b55576de79a2559a71db 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rowSums.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rowSums.java @@ -100,6 +100,6 @@ public class TestBuiltin_rowSums extends TestBase { // rowSums on array have correct values assertEval("{ a = rowSums(array(1:24,c(2,3,4))); c(a[1],a[2]) }"); - assertEval(Output.IgnoreErrorContext, "{x<-matrix(c(\"1\",\"2\",\"3\",\"4\"),ncol=2);rowSums(x)}"); + assertEval("{x<-matrix(c(\"1\",\"2\",\"3\",\"4\"),ncol=2);rowSums(x)}"); } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_sweep.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_sweep.java index 0053dda72df7dfcc34a63da24a56f8f9732ca04d..f3f28eb9901d0e5c3e4d8095a63329ce6fbd9549 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_sweep.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_sweep.java @@ -35,7 +35,5 @@ public class TestBuiltin_sweep extends TestBase { // Correct output but warnings assertEval(Ignored.Unknown, "{ A <- matrix(1:50, nrow=4); sweep(A, 1, 5, '-') }"); assertEval(Ignored.Unknown, "{ A <- matrix(7:1, nrow=5); sweep(A, 1, -1, '*') }"); - - assertEval("{rowMeans(matrix(c(NaN,4+5i,2+0i,5+10i),nrow=2,ncol=2), na.rm = FALSE)}"); } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/base/TestSimpleArithmetic.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/base/TestSimpleArithmetic.java index 4dc124ca5c8513ffd5589e4a61be7be0b720ccfb..239cd80e0dc7ce48f3d45b4049ed7892638b3dbd 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/base/TestSimpleArithmetic.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/library/base/TestSimpleArithmetic.java @@ -307,6 +307,7 @@ public class TestSimpleArithmetic extends TestBase { public void testVectorsComplex() { assertEval("{ 1:4+c(1,2+2i) }"); assertEval("{ c(1,2+2i)+1:4 }"); + assertEval(Ignored.ImplementationError, "x <- c(NaN, 3+2i); xre <- Re(x); xim <- (0+1i) * Im(x); xre + xim"); } @Test