From bf443363c7547f35f5b1aa5f3b14b560c8f04e7d Mon Sep 17 00:00:00 2001 From: stepan <stepan.sindelar@oracle.com> Date: Sun, 28 Aug 2016 12:45:45 +0200 Subject: [PATCH] Row and RowMeans: converted to pipelines --- .../truffle/r/nodes/builtin/base/Row.java | 18 ++++++++++-------- .../truffle/r/nodes/builtin/base/RowMeans.java | 16 +++++----------- .../truffle/r/test/ExpectedTestOutput.test | 4 ++++ .../r/test/builtins/TestBuiltin_row.java | 5 +++++ .../r/test/builtins/TestBuiltin_rowMeans.java | 2 +- 5 files changed, 25 insertions(+), 20 deletions(-) diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Row.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Row.java index 7f03a651e9..a250494979 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Row.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Row.java @@ -12,22 +12,30 @@ */ package com.oracle.truffle.r.nodes.builtin.base; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.integerValue; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.size; +import static com.oracle.truffle.r.runtime.RError.SHOW_CALLER; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; -import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RIntVector; -import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; @RBuiltin(name = "row", kind = INTERNAL, parameterNames = {"dims"}, behavior = PURE) public abstract class Row extends RBuiltinNode { + @Override + protected void createCasts(CastBuilder casts) { + casts.arg("dims").mustBe(nullValue().not().and(integerValue()), SHOW_CALLER, RError.Message.MATRIX_LIKE_REQUIRED, "row").asIntegerVector().mustBe(size(2)); + } + @Specialization protected RIntVector col(RAbstractIntVector x) { int nrows = x.getDataAt(0); @@ -40,10 +48,4 @@ public abstract class Row extends RBuiltinNode { } return RDataFactory.createIntVector(result, RDataFactory.COMPLETE_VECTOR, new int[]{nrows, ncols}); } - - @Specialization - @TruffleBoundary - protected RIntVector col(@SuppressWarnings("unused") RNull x) { - throw RError.error(this, RError.Message.MATRIX_LIKE_REQUIRED, "row"); - } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowMeans.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowMeans.java index 5d97a4dac5..ca940f61df 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowMeans.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RowMeans.java @@ -10,23 +10,23 @@ */ package com.oracle.truffle.r.nodes.builtin.base; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean; +import static com.oracle.truffle.r.runtime.RError.SHOW_CALLER; +import static com.oracle.truffle.r.runtime.RError.Message.X_NUMERIC; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; -import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; -import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDoubleVector; import com.oracle.truffle.r.runtime.data.RIntVector; import com.oracle.truffle.r.runtime.data.RLogicalVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.ops.BinaryArithmetic; import com.oracle.truffle.r.runtime.ops.na.NACheck; @@ -39,6 +39,8 @@ public abstract class RowMeans extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { + casts.arg("X").mustBe(numericValue(), SHOW_CALLER, X_NUMERIC); + casts.arg("m").asIntegerVector().findFirst().notNA(); casts.arg("n").asIntegerVector().findFirst().notNA(); @@ -189,12 +191,4 @@ public abstract class RowMeans extends RBuiltinNode { } return RDataFactory.createDoubleVector(result, isComplete); } - - @SuppressWarnings("unused") - @Specialization - protected RDoubleVector rowMeans(RAbstractStringVector x, int rowNum, int colNum, boolean naRm) { - CompilerDirectives.transferToInterpreter(); - throw RError.error(this, RError.Message.X_NUMERIC); - } - } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index 5db9a7a28e..01052b8dfe 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -41706,6 +41706,10 @@ Error in rm(tmp, inherits = "asd") : invalid 'inherits' argument 18 0.138 0.758 0.141 0.761 19 0.122 0.775 0.124 0.777 +##com.oracle.truffle.r.test.builtins.TestBuiltin_row.testArgsCasts +#.Internal(row('str')) +Error: a matrix-like object is required as argument to 'row' + ##com.oracle.truffle.r.test.builtins.TestBuiltin_row.testrow1 #argv <- list(c(14L, 14L)); .Internal(row(argv[[1]])) [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13] diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_row.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_row.java index 7f9604666f..e0fcedf859 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_row.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_row.java @@ -31,4 +31,9 @@ public class TestBuiltin_row extends TestBase { public void testrow3() { assertEval("argv <- list(0:1); .Internal(row(argv[[1]]))"); } + + @Test + public void testArgsCasts() { + assertEval(".Internal(row('str'))"); + } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rowMeans.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rowMeans.java index 8e14c8b34e..8d61e42140 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rowMeans.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/builtins/TestBuiltin_rowMeans.java @@ -58,8 +58,8 @@ public class TestBuiltin_rowMeans extends TestBase { // Whichever value(NA or NaN) is first in the row will be returned for that row. assertEval("{rowMeans(matrix(c(NA,NaN,NaN,NA),ncol=2,nrow=2))}"); + assertEval("{x<-matrix(c(\"1\",\"2\",\"3\",\"4\"),ncol=2);rowMeans(x)}"); // Error message mismatch assertEval(Ignored.Unknown, "{rowMeans(matrix(NA,NA,NA),TRUE)}"); - assertEval(Output.IgnoreErrorContext, "{x<-matrix(c(\"1\",\"2\",\"3\",\"4\"),ncol=2);rowMeans(x)}"); } } -- GitLab