diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/CastBuilderTest.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/CastBuilderTest.java index a5fda290761b3ae8f4614e466a4e0bc84757dcb1..59c53a3b2bf420f933def6b0e638d534cb6256ff 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/CastBuilderTest.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/CastBuilderTest.java @@ -22,42 +22,9 @@ */ package com.oracle.truffle.r.nodes.builtin; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.asLogicalVector; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.asStringVector; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.chain; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.complexValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.constant; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.defaultValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.doubleNA; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.doubleToInt; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.doubleValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.elementAt; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.equalTo; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.findFirst; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gte; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.instanceOf; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.integerValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.isFractional; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.lte; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.map; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.mustBe; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.notEmpty; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.notNA; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullConstant; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.scalarLogicalValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.scalarStringValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.shouldBe; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.singleElement; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.stringValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.trueValue; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.*; import static com.oracle.truffle.r.nodes.casts.CastUtils.samples; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.junit.Assert.*; import java.util.function.Function; @@ -87,10 +54,13 @@ import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.data.RComplex; import com.oracle.truffle.r.runtime.data.RDataFactory; +import com.oracle.truffle.r.runtime.data.RDoubleVector; +import com.oracle.truffle.r.runtime.data.RIntVector; import com.oracle.truffle.r.runtime.data.RInteger; import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RLogical; import com.oracle.truffle.r.runtime.data.RNull; +import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; @@ -665,6 +635,28 @@ public class CastBuilderTest { assertTrue(RRuntime.isNA((double) r)); } + @Test + public void testSample18() { + cb.arg(0, "matrix").asDoubleVector(true, true, true).mustBe(squareMatrix()); + + RIntVector vec = RDataFactory.createIntVector(new int[]{0, 1, 2, 3}, true, new int[]{2, 2}); + Object res = cast(vec); + assertTrue(res instanceof RAbstractDoubleVector); + RAbstractDoubleVector dvec = (RAbstractDoubleVector) res; + assertNotNull(dvec.getDimensions()); + assertEquals(2, dvec.getDimensions().length); + assertEquals(2, dvec.getDimensions()[0]); + assertEquals(2, dvec.getDimensions()[1]); + } + + @Test + public void testSample19() { + cb.arg(0, "matrix").asDoubleVector(true, true, true).mustBe(dimGt(1, 0)); + + RIntVector vec = RDataFactory.createIntVector(new int[]{0, 1, 2, 3}, true, new int[]{2, 2}); + cast(vec); + } + @Test public void testPreserveNonVectorFlag() { cb.arg(0, "x").asVector(true); diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/PredefFiltersSamplers.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/PredefFiltersSamplers.java index 4ec9ca46f6a12f44c3005b8e5772b8aca6f355ba..dcc8f6caa5c7ab51dc1c889e7d307f8276e9d863 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/PredefFiltersSamplers.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/PredefFiltersSamplers.java @@ -30,6 +30,7 @@ import java.util.Objects; import com.oracle.truffle.r.nodes.builtin.CastBuilder.PredefFilters; import com.oracle.truffle.r.nodes.builtin.ValuePredicateArgumentFilter; +import com.oracle.truffle.r.nodes.builtin.VectorPredicateArgumentFilter; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.data.RComplex; import com.oracle.truffle.r.runtime.data.RDataFactory; @@ -101,6 +102,26 @@ public final class PredefFiltersSamplers implements PredefFilters { return new VectorPredicateArgumentFilterSampler<>("elementAt", x -> index < x.getLength() && value == (byte) (x.getDataAtAsObject(index)), false, 0, index); } + @Override + public <T extends RAbstractVector> VectorPredicateArgumentFilterSampler<T> matrix() { + return new VectorPredicateArgumentFilterSampler<>("matrix", x -> x.isMatrix(), false); + } + + @Override + public <T extends RAbstractVector> VectorPredicateArgumentFilterSampler<T> squareMatrix() { + return new VectorPredicateArgumentFilterSampler<>("squareMatrix", x -> x.isMatrix() && x.getDimensions()[0] == x.getDimensions()[1], false, 3); + } + + @Override + public <T extends RAbstractVector> VectorPredicateArgumentFilterSampler<T> dimEq(int dim, int x) { + return new VectorPredicateArgumentFilterSampler<>("dimGt", v -> v.isMatrix() && v.getDimensions().length == dim && v.getDimensions()[dim] > x, false, dim - 1); + } + + @Override + public <T extends RAbstractVector> VectorPredicateArgumentFilterSampler<T> dimGt(int dim, int x) { + return new VectorPredicateArgumentFilterSampler<>("dimGt", v -> v.isMatrix() && v.getDimensions().length > dim && v.getDimensions()[dim] > x, false, dim - 1); + } + @Override public <T extends RAbstractVector, R extends T> VectorPredicateArgumentFilterSampler<T> size(int s) { if (s == 0) { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/CastBuilder.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/CastBuilder.java index 4c235819a8f2c14e7dfc3e0c8271d6be1f4e3a8a..e0f8ca0a135f46fc2d2a87f9544d2d8513037015 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/CastBuilder.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/CastBuilder.java @@ -252,6 +252,14 @@ public final class CastBuilder { VectorPredicateArgumentFilter<RAbstractLogicalVector> elementAt(int index, byte value); + <T extends RAbstractVector> VectorPredicateArgumentFilter<T> matrix(); + + <T extends RAbstractVector> VectorPredicateArgumentFilter<T> squareMatrix(); + + <T extends RAbstractVector> VectorPredicateArgumentFilter<T> dimEq(int dim, int x); + + <T extends RAbstractVector> VectorPredicateArgumentFilter<T> dimGt(int dim, int x); + ValuePredicateArgumentFilter<Boolean> trueValue(); ValuePredicateArgumentFilter<Boolean> falseValue(); @@ -402,6 +410,26 @@ public final class CastBuilder { return new VectorPredicateArgumentFilter<>(x -> index < x.getLength() && value == (byte) (x.getDataAtAsObject(index)), false); } + @Override + public <T extends RAbstractVector> VectorPredicateArgumentFilter<T> matrix() { + return new VectorPredicateArgumentFilter<>(x -> x.isMatrix(), false); + } + + @Override + public <T extends RAbstractVector> VectorPredicateArgumentFilter<T> squareMatrix() { + return new VectorPredicateArgumentFilter<>(x -> x.isMatrix() && x.getDimensions()[0] == x.getDimensions()[1], false); + } + + @Override + public <T extends RAbstractVector> VectorPredicateArgumentFilter<T> dimEq(int dim, int x) { + return new VectorPredicateArgumentFilter<>(v -> v.isMatrix() && v.getDimensions().length > dim && v.getDimensions()[dim] == x, false); + } + + @Override + public <T extends RAbstractVector> VectorPredicateArgumentFilter<T> dimGt(int dim, int x) { + return new VectorPredicateArgumentFilter<>(v -> v.isMatrix() && v.getDimensions().length > dim && v.getDimensions()[dim] > x, false); + } + @Override public ValuePredicateArgumentFilter<Boolean> trueValue() { return ValuePredicateArgumentFilter.fromLambda(x -> x); @@ -712,6 +740,18 @@ public final class CastBuilder { return predefMappers; } + public static <T> ArgumentValueFilter<T> not(ArgumentValueFilter<T> filter) { + return filter.not(); + } + + public static <T> ArgumentValueFilter<T> and(ArgumentValueFilter<T> filter1, ArgumentValueFilter<T> filter2) { + return filter1.and(filter2); + } + + public static <T> ArgumentValueFilter<T> or(ArgumentValueFilter<T> filter1, ArgumentValueFilter<T> filter2) { + return filter1.or(filter2); + } + public static <T> Function<ArgCastBuilder<T, ?>, CastNode> mustBe(ArgumentFilter<?, ?> argFilter, RBaseNode callObj, boolean boxPrimitives, RError.Message message, Object... messageArgs) { return phaseBuilder -> FilterNode.create(argFilter, false, callObj, message, messageArgs, boxPrimitives); } @@ -895,6 +935,22 @@ public final class CastBuilder { return predefFilters().elementAt(index, value); } + public static <T extends RAbstractVector> VectorPredicateArgumentFilter<T> matrix() { + return predefFilters().matrix(); + } + + public static <T extends RAbstractVector> VectorPredicateArgumentFilter<T> squareMatrix() { + return predefFilters().squareMatrix(); + } + + public static <T extends RAbstractVector> VectorPredicateArgumentFilter<T> dimEq(int dim, int x) { + return predefFilters().dimEq(dim, x); + } + + public static <T extends RAbstractVector> VectorPredicateArgumentFilter<T> dimGt(int dim, int x) { + return predefFilters().dimGt(dim, x); + } + public static ValuePredicateArgumentFilter<Boolean> trueValue() { return predefFilters().trueValue(); } @@ -1509,11 +1565,15 @@ public final class CastBuilder { return asDoubleVector(false, false, false); } - default CoercedPhaseBuilder<RAbstractLogicalVector, Byte> asLogicalVector() { - state().castBuilder().toLogical(state().index()); + default CoercedPhaseBuilder<RAbstractDoubleVector, Byte> asLogicalVector(boolean preserveNames, boolean dimensionsPreservation, boolean attrPreservation) { + state().castBuilder().insert(state().index(), CastLogicalNodeGen.create(preserveNames, dimensionsPreservation, attrPreservation)); return state().factory.newCoercedPhaseBuilder(this, Byte.class); } + default CoercedPhaseBuilder<RAbstractDoubleVector, Byte> asLogicalVector() { + return asLogicalVector(false, false, false); + } + default CoercedPhaseBuilder<RAbstractStringVector, String> asStringVector(boolean preserveNames, boolean dimensionsPreservation, boolean attrPreservation) { state().castBuilder().toCharacter(state().index(), preserveNames, dimensionsPreservation, attrPreservation); return state().factory.newCoercedPhaseBuilder(this, String.class);