From fbec5d3882e2a16123fb517a29ec21873ca1aeb5 Mon Sep 17 00:00:00 2001 From: Zbynek Slajchrt <zbynek.slajchrt@oracle.com> Date: Thu, 26 Jan 2017 17:50:33 +0100 Subject: [PATCH] Forwarded types analyser for cast pipelines added --- .../truffle/r/library/stats/Covcor.java | 2 +- .../oracle/truffle/r/library/utils/Rprof.java | 1 - .../truffle/r/library/utils/TypeConvert.java | 1 - .../r/nodes/builtin/base/AsCharacter.java | 2 +- .../r/nodes/builtin/base/AsComplex.java | 2 +- .../r/nodes/builtin/base/AsDouble.java | 4 +- .../r/nodes/builtin/base/AsInteger.java | 4 +- .../r/nodes/builtin/base/AsLogical.java | 4 +- .../truffle/r/nodes/builtin/base/Bind.java | 1 - .../nodes/builtin/base/BrowserFunctions.java | 4 +- .../r/nodes/builtin/base/GetFunctions.java | 6 +- .../truffle/r/nodes/builtin/base/GetText.java | 4 +- .../r/nodes/builtin/base/IntToBits.java | 4 +- .../r/nodes/builtin/base/IntToUtf8.java | 4 +- .../r/nodes/builtin/base/IsTypeFunctions.java | 5 +- .../r/nodes/builtin/base/MatchArg.java | 4 +- .../truffle/r/nodes/builtin/base/NZChar.java | 4 +- .../truffle/r/nodes/builtin/base/Parse.java | 2 +- .../r/nodes/builtin/base/Quantifier.java | 1 - .../truffle/r/nodes/builtin/base/Sample.java | 4 +- .../r/nodes/builtin/base/SetS4Object.java | 2 +- .../truffle/r/nodes/builtin/base/Slot.java | 4 +- .../truffle/r/nodes/builtin/base/Sum.java | 2 +- .../truffle/r/nodes/builtin/base/UnClass.java | 2 +- .../nodes/builtin/base/UpdateAttributes.java | 3 +- .../r/nodes/builtin/base/UpdateClass.java | 2 +- .../r/nodes/builtin/base/UpdateLength.java | 4 +- .../r/nodes/builtin/base/UpdateSlot.java | 2 +- .../r/nodes/builtin/base/WhichFunctions.java | 4 +- .../r/nodes/builtin/CastBuilderTest.java | 84 +++- .../r/nodes/casts/CastNodeSampler.java | 21 +- .../r/nodes/casts/FilterSamplerFactory.java | 16 +- .../test/ForwardedValuesAnalyserTest.java | 409 ++++++++++++++++++ .../r/nodes/test/PipelineToCastNodeTests.java | 5 +- .../r/nodes/unary/BypassNodeGenSampler.java | 4 +- .../unary/CastToVectorNodeGenSampler.java | 2 +- .../r/nodes/unary/ChainedCastNodeSampler.java | 8 +- .../unary/ConditionalMapNodeGenSampler.java | 27 +- .../r/nodes/unary/FilterNodeGenSampler.java | 4 +- .../nodes/unary/FindFirstNodeGenSampler.java | 2 +- .../truffle/r/nodes/unary/MapNodeSampler.java | 2 +- .../r/nodes/unary/NonNANodeGenSampler.java | 2 +- .../truffle/r/nodes/builtin/CastBuilder.java | 28 +- .../truffle/r/nodes/builtin/casts/Filter.java | 75 +++- .../r/nodes/builtin/casts/PipelineConfig.java | 9 +- .../r/nodes/builtin/casts/PipelineStep.java | 18 +- .../builtin/casts/PipelineToCastNode.java | 23 +- .../builtin/casts/ValueForwardingNode.java | 94 ++++ .../analysis/ForwardedValuesAnalyser.java | 317 ++++++++++++++ .../analysis/ForwardingAnalysisResult.java | 336 ++++++++++++++ .../casts/analysis/ForwardingStatus.java | 118 +++++ .../builtin/casts/fluent/CastNodeBuilder.java | 4 +- .../casts/fluent/CoercedPhaseBuilder.java | 8 +- .../casts/fluent/HeadPhaseBuilder.java | 33 +- .../casts/fluent/InitialPhaseBuilder.java | 33 +- .../builtin/casts/fluent/PipelineBuilder.java | 20 +- .../casts/fluent/PipelineConfigBuilder.java | 55 +-- .../casts/fluent/PreinitialPhaseBuilder.java | 57 +-- .../truffle/r/nodes/unary/BypassNode.java | 15 +- .../r/nodes/unary/ConditionalMapNode.java | 48 +- 60 files changed, 1751 insertions(+), 214 deletions(-) create mode 100644 com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/ForwardedValuesAnalyserTest.java create mode 100644 com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/ValueForwardingNode.java create mode 100644 com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/analysis/ForwardedValuesAnalyser.java create mode 100644 com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/analysis/ForwardingAnalysisResult.java create mode 100644 com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/analysis/ForwardingStatus.java diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Covcor.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Covcor.java index 8d692d6efc..84091dec27 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Covcor.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Covcor.java @@ -50,7 +50,7 @@ public abstract class Covcor extends RExternalBuiltinNode.Arg4 { @Override protected void createCasts(CastBuilder casts) { casts.arg(0).mustNotBeNull(SHOW_CALLER, Message.IS_NULL, "x").asDoubleVector(); - casts.arg(1).allowNull().asDoubleVector(); + casts.arg(1).asDoubleVector(); casts.arg(2).asIntegerVector().findFirst().mustBe(eq(4), this, Message.NYI, "covcor: other method than 4 not implemented."); casts.arg(3).asLogicalVector().findFirst().map(toBoolean()); } diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/utils/Rprof.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/utils/Rprof.java index 784bd50cfe..4a9ab559d6 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/utils/Rprof.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/utils/Rprof.java @@ -56,7 +56,6 @@ import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RObjectSize; import com.oracle.truffle.r.runtime.data.RTypedValue; -import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.instrument.InstrumentationState; import com.oracle.truffle.r.runtime.nodes.RSyntaxElement; diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/utils/TypeConvert.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/utils/TypeConvert.java index 81362811de..88a47ede7f 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/utils/TypeConvert.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/utils/TypeConvert.java @@ -35,7 +35,6 @@ import com.oracle.truffle.r.runtime.data.RIntVector; import com.oracle.truffle.r.runtime.data.RLogicalVector; import com.oracle.truffle.r.runtime.data.RVector; import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractVector; public abstract class TypeConvert extends RExternalBuiltinNode.Arg5 { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsCharacter.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsCharacter.java index 04ca61957f..3dc2d41575 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsCharacter.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsCharacter.java @@ -48,7 +48,7 @@ public abstract class AsCharacter extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { - casts.arg("x").allowNull().mapIf(instanceOf(RAbstractListVector.class).not(), asStringVector()); + casts.arg("x").mapIf(instanceOf(RAbstractListVector.class).not(), asStringVector()); } @Specialization diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsComplex.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsComplex.java index 814a0ec235..513e3e9adb 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsComplex.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsComplex.java @@ -42,7 +42,7 @@ public abstract class AsComplex extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { - casts.arg("x").allowNull().asComplexVector(); + casts.arg("x").asComplexVector(); } @Specialization diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsDouble.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsDouble.java index 7f8d5c1015..88563df392 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsDouble.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsDouble.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2013, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -41,7 +41,7 @@ public abstract class AsDouble extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { - casts.arg("x").allowNull().asDoubleVector(); + casts.arg("x").asDoubleVector(); } @Specialization diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsInteger.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsInteger.java index 9bc2ad81dd..ab4de27010 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsInteger.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsInteger.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2013, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -41,7 +41,7 @@ public abstract class AsInteger extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { - casts.arg("x").allowNull().asIntegerVector(); + casts.arg("x").asIntegerVector(); } @Specialization diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsLogical.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsLogical.java index fbdc14690c..41a7bb1674 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsLogical.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/AsLogical.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2013, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -41,7 +41,7 @@ public abstract class AsLogical extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { - casts.arg("x").allowNull().asLogicalVector(); + casts.arg("x").asLogicalVector(); } @Specialization diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Bind.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Bind.java index 4519d18124..d95c2c135e 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Bind.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Bind.java @@ -32,7 +32,6 @@ import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.frame.VirtualFrame; -import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.api.profiles.ValueProfile; diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BrowserFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BrowserFunctions.java index 207e01ce5d..b84302c5ff 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BrowserFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BrowserFunctions.java @@ -22,7 +22,7 @@ */ package com.oracle.truffle.r.nodes.builtin.base; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.anyValue; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gt; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean; import static com.oracle.truffle.r.runtime.RVisibility.OFF; @@ -64,7 +64,7 @@ public class BrowserFunctions { @Override protected void createCasts(CastBuilder casts) { // TODO: add support for conditions conditions - casts.arg("condition").allowNull().mustBe(anyValue().not(), RError.Message.GENERIC, "Only NULL conditions currently supported in browser"); + casts.arg("condition").mustBe(nullValue(), RError.Message.GENERIC, "Only NULL conditions currently supported in browser"); casts.arg("expr").asLogicalVector().findFirst(RRuntime.LOGICAL_FALSE).map(toBoolean()); casts.arg("skipCalls").asIntegerVector().findFirst(0); } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/GetFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/GetFunctions.java index 5e425c3111..0fcf141e87 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/GetFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/GetFunctions.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2014, 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2014, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -174,7 +174,7 @@ public class GetFunctions { casts.arg("envir").mustBe(instanceOf(REnvironment.class).or(integerValue()).or(doubleValue()).or(instanceOf(RS4Object.class))).mapIf(integerValue().or(doubleValue()), chain(asIntegerVector()).with(findFirst().integerElement()).end()); casts.arg("mode").mustBe(stringValue()).asStringVector().findFirst(); - casts.arg("inherits").allowNull().asLogicalVector().findFirst().map(toBoolean()); + casts.arg("inherits").asLogicalVector().findFirst().map(toBoolean()); } @Specialization @@ -211,7 +211,7 @@ public class GetFunctions { casts.arg("envir").mustBe(instanceOf(REnvironment.class).or(integerValue()).or(doubleValue()).or(instanceOf(RS4Object.class))).mapIf(integerValue().or(doubleValue()), chain(asIntegerVector()).with(findFirst().integerElement()).end()); casts.arg("mode").mustBe(stringValue()).asStringVector().findFirst(); - casts.arg("inherits").allowNull().asLogicalVector().findFirst().map(toBoolean()); + casts.arg("inherits").asLogicalVector().findFirst().map(toBoolean()); } @Specialization diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/GetText.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/GetText.java index be2ae78d87..8e7d4ca5a1 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/GetText.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/GetText.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2014, 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2014, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -38,7 +38,7 @@ public abstract class GetText extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { casts.arg("domain").asStringVector().findFirst(""); - casts.arg("args").allowNull().asStringVector(); + casts.arg("args").asStringVector(); } @Specialization diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IntToBits.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IntToBits.java index 9b65d700b6..08ed846aea 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IntToBits.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IntToBits.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2016, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -39,7 +39,7 @@ public abstract class IntToBits extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { - casts.arg("x").allowNull().asIntegerVector(); + casts.arg("x").asIntegerVector(); } @Specialization diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IntToUtf8.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IntToUtf8.java index 701d18fb0d..030149ca78 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IntToUtf8.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IntToUtf8.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2016, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -46,7 +46,7 @@ public abstract class IntToUtf8 extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { - casts.arg("x").allowNull().asIntegerVector(); + casts.arg("x").asIntegerVector(); casts.arg("multiple").mustNotBeNull().asLogicalVector().findFirst().map(toBoolean()); } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IsTypeFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IsTypeFunctions.java index b33249c592..c8337dda05 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IsTypeFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IsTypeFunctions.java @@ -66,6 +66,7 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector; import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.nodes.RBaseNode; /** * Handles all builtin functions of the form {@code is.xxx}, where is {@code xxx} is a "type". @@ -77,7 +78,7 @@ public class IsTypeFunctions { @Override protected void createCasts(CastBuilder casts) { - casts.arg("x").conf(c -> c.allowNull().mustNotBeMissing(null, RError.Message.ARGUMENT_MISSING, "x")); + casts.arg("x").mustNotBeMissing((RBaseNode) null, RError.Message.ARGUMENT_MISSING, "x"); } } @@ -487,7 +488,7 @@ public class IsTypeFunctions { @Override protected void createCasts(CastBuilder casts) { - casts.arg("x").conf(c -> c.allowNull().mustNotBeMissing(null, RError.Message.ARGUMENT_MISSING, "x")); + casts.arg("x").mustNotBeMissing((RBaseNode) null, RError.Message.ARGUMENT_MISSING, "x"); casts.arg("mode").defaultError(this, RError.Message.INVALID_ARGUMENT, "mode").mustBe(stringValue()).asStringVector().mustBe(size(1)).findFirst(); } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/MatchArg.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/MatchArg.java index bc6405b87d..52cd9569f1 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/MatchArg.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/MatchArg.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2016, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -80,7 +80,7 @@ public abstract class MatchArg extends RBuiltinNode { { CastBuilder builder = new CastBuilder(); - builder.arg(0).allowNull().asStringVector(); + builder.arg(0).asStringVector(); builder.arg(1).allowMissing().mustBe(stringValue()).asStringVector(); builder.arg(2).mustBe(logicalValue()).asLogicalVector().findFirst().map(toBoolean()); this.casts = builder.getCasts(); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/NZChar.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/NZChar.java index c8c192044c..cd2f85f3ec 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/NZChar.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/NZChar.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2014, 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2014, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -41,7 +41,7 @@ public abstract class NZChar extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { - casts.arg("x").allowNull().asStringVector(); + casts.arg("x").asStringVector(); casts.arg("keepNA").asLogicalVector().findFirst(RRuntime.LOGICAL_FALSE).map(toBoolean()); } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Parse.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Parse.java index fbb66b60b7..dbf7704245 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Parse.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Parse.java @@ -106,7 +106,7 @@ public abstract class Parse extends RBuiltinNode { // Note: string is captured by the R wrapper and transformed to a file, other types not casts.arg("conn").defaultError(MUST_BE_STRING_OR_CONNECTION, "file").mustNotBeNull().asIntegerVector().findFirst(); casts.arg("n").asIntegerVector().findFirst(RRuntime.INT_NA).notNA(-1); - casts.arg("text").allowNull().asStringVector(); + casts.arg("text").asStringVector(); casts.arg("prompt").asStringVector().findFirst("?"); casts.arg("encoding").mustBe(stringValue()).asStringVector().findFirst(); } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Quantifier.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Quantifier.java index d2f9d9588b..f0d079140d 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Quantifier.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Quantifier.java @@ -36,7 +36,6 @@ import com.oracle.truffle.api.nodes.ExplodeLoop; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.r.nodes.builtin.CastBuilder; -import com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.unary.CastNode; import com.oracle.truffle.r.runtime.RError; diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sample.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sample.java index 682bbc28bc..128290e690 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sample.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sample.java @@ -7,7 +7,7 @@ * Copyright (c) 1997-2012, The R Core Team * Copyright (c) 2003-2008, The R Foundation * Copyright (c) 2014, Purdue University - * Copyright (c) 2014, 2016, Oracle and/or its affiliates + * Copyright (c) 2014, 2017, Oracle and/or its affiliates * * All rights reserved. */ @@ -70,7 +70,7 @@ public abstract class Sample extends RBuiltinNode { notNA().mustBe(gte0()); casts.arg("replace").mustBe(integerValue().or(doubleValue()).or(logicalValue())). asLogicalVector().mustBe(singleElement()).findFirst().notNA().map(toBoolean()); - casts.arg("prob").allowNull().asDoubleVector(); + casts.arg("prob").asDoubleVector(); // @formatter:on } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SetS4Object.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SetS4Object.java index 4bdc429907..4fa28d6d07 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SetS4Object.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/SetS4Object.java @@ -47,7 +47,7 @@ public abstract class SetS4Object extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { - casts.arg("object").allowNull().asAttributable(true, true, true); + casts.arg("object").asAttributable(true, true, true); casts.arg("flag").asLogicalVector().mustBe(singleElement(), RError.SHOW_CALLER, RError.Message.INVALID_ARGUMENT, "flag").findFirst().map(toBoolean()); // "complete" can be a vector, unlike "flag" casts.arg("complete").asIntegerVector().findFirst(RError.SHOW_CALLER, RError.Message.INVALID_ARGUMENT, "complete"); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Slot.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Slot.java index 6c4e907480..086d7d1fae 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Slot.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Slot.java @@ -6,7 +6,7 @@ * Copyright (c) 1995, 1996, 1997 Robert Gentleman and Ross Ihaka * Copyright (c) 1995-2014, The R Core Team * Copyright (c) 2002-2008, The R Foundation - * Copyright (c) 2015, 2016, Oracle and/or its affiliates + * Copyright (c) 2015, 2017, Oracle and/or its affiliates * * All rights reserved. */ @@ -40,7 +40,7 @@ public abstract class Slot extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { - casts.arg(0).allowNull().asAttributable(true, true, true); + casts.arg(0).asAttributable(true, true, true); } private String getName(Object nameObj) { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sum.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sum.java index a8600926f8..e0571ca49d 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sum.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Sum.java @@ -61,7 +61,7 @@ public abstract class Sum extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { - casts.arg("na.rm").allowNull().asLogicalVector().findFirst().map(toBoolean()); + casts.arg("na.rm").asLogicalVector().findFirst(RRuntime.LOGICAL_NA).map(toBoolean()); } @Override diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UnClass.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UnClass.java index 37c4910689..ce157c2c11 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UnClass.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UnClass.java @@ -36,7 +36,7 @@ public abstract class UnClass extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { - casts.arg("x").allowNull().asAttributable(true, true, true); + casts.arg("x").asAttributable(true, true, true); } @Specialization diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateAttributes.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateAttributes.java index e346d88aad..da512cd398 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateAttributes.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateAttributes.java @@ -23,6 +23,7 @@ package com.oracle.truffle.r.nodes.builtin.base; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.instanceOf; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue; import static com.oracle.truffle.r.runtime.RError.Message.ATTRIBUTES_LIST_OR_NULL; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; @@ -76,7 +77,7 @@ public abstract class UpdateAttributes extends RBuiltinNode { // Note: cannot check 'attributability' easily because atomic values, e.g int, are not // RAttributable. casts.arg("obj"); // by default disallows RNull - casts.arg("value").conf(c -> c.allowNull()).mustBe(instanceOf(RList.class), this, ATTRIBUTES_LIST_OR_NULL); + casts.arg("value").mustBe(nullValue().or(instanceOf(RList.class)), this, ATTRIBUTES_LIST_OR_NULL); } // it's OK for the following two methods to update attributes in-place as the container has been diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateClass.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateClass.java index 38f02a45b9..120e4b10ed 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateClass.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateClass.java @@ -54,7 +54,7 @@ public abstract class UpdateClass extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { casts.arg("x"); // disallows null - casts.arg("value").allowNull().asStringVector(); + casts.arg("value").asStringVector(); } @Specialization diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateLength.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateLength.java index 67a5eae4cb..840f1ea37e 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateLength.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateLength.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2013, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -47,7 +47,7 @@ public abstract class UpdateLength extends RBuiltinNode { protected void createCasts(CastBuilder casts) { // Note: `length<-`(NULL, newLen) really works in GnuR unlike other update builtins // @formatter:off - casts.arg("x").conf(c -> c.allowNull()).mustBe(abstractVectorValue(), this, INVALID_UNNAMED_ARGUMENT); + casts.arg("x").allowNull().mustBe(abstractVectorValue(), this, INVALID_UNNAMED_ARGUMENT); casts.arg("value").defaultError(this, INVALID_UNNAMED_VALUE). mustBe(integerValue().or(doubleValue()).or(stringValue())). asIntegerVector().mustBe(singleElement()).findFirst(); diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateSlot.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateSlot.java index 13cc5d4bc7..d97465df82 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateSlot.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateSlot.java @@ -58,7 +58,7 @@ public abstract class UpdateSlot extends RBuiltinNode { @Override protected void createCasts(CastBuilder casts) { - casts.arg(0).allowNull().asAttributable(true, true, true); + casts.arg(0).asAttributable(true, true, true); } protected String getName(Object nameObj) { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/WhichFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/WhichFunctions.java index 14614074f9..d61abdf172 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/WhichFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/WhichFunctions.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2013, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -113,7 +113,7 @@ public class WhichFunctions { @Override protected void createCasts(CastBuilder casts) { - casts.arg(0, "x").allowNull().asDoubleVector(true, false, false); + casts.arg(0, "x").asDoubleVector(true, false, false); } @Specialization diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/CastBuilderTest.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/CastBuilderTest.java index f092a88c2a..531ad1205d 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/CastBuilderTest.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/CastBuilderTest.java @@ -22,8 +22,7 @@ */ package com.oracle.truffle.r.nodes.builtin; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.doubleValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.abstractVectorValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.asBoolean; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.asInteger; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.asIntegerVector; @@ -37,6 +36,7 @@ import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.constant; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.dimGt; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.doubleNA; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.doubleToInt; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.doubleValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.elementAt; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.emptyStringVector; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.findFirst; @@ -44,17 +44,20 @@ import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gt; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gt0; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gte; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.instanceOf; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.intNA; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.integerValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.isFractional; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalTrue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.lte; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.map; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.missingValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.mustBe; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.not; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.notEmpty; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.notNA; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullConstant; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.numericValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.shouldBe; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.singleElement; @@ -82,6 +85,7 @@ import com.oracle.truffle.r.nodes.casts.CastNodeSampler; import com.oracle.truffle.r.nodes.casts.FilterSamplerFactory; import com.oracle.truffle.r.nodes.casts.MapperSamplerFactory; import com.oracle.truffle.r.nodes.casts.Samples; +import com.oracle.truffle.r.nodes.casts.TypeExpr; import com.oracle.truffle.r.nodes.test.TestUtilities; import com.oracle.truffle.r.nodes.test.TestUtilities.NodeHandle; import com.oracle.truffle.r.nodes.unary.CastNode; @@ -117,7 +121,7 @@ public class CastBuilderTest { private static final boolean TEST_SAMPLING = false; private CastBuilder cb; - private PreinitialPhaseBuilder<Object> arg; + private PreinitialPhaseBuilder arg; static { if (TEST_SAMPLING) { @@ -396,6 +400,9 @@ public class CastBuilderTest { @Test public void testLogicalToBooleanPipeline() { arg.asLogicalVector().findFirst(RRuntime.LOGICAL_FALSE).map(toBoolean()); + assertEquals(Boolean.TRUE, cast(RRuntime.LOGICAL_TRUE)); + assertEquals(Boolean.FALSE, cast(RRuntime.LOGICAL_FALSE)); + assertEquals(Boolean.FALSE, cast(RRuntime.LOGICAL_NA)); assertEquals(Boolean.TRUE, cast(RDataFactory.createLogicalVector(new byte[]{RRuntime.LOGICAL_TRUE, RRuntime.LOGICAL_FALSE}, true))); assertEquals(Boolean.FALSE, cast(RDataFactory.createLogicalVector(0))); testPipeline(NO_FILTER_EXPECT_EMPTY_SAMPLES); @@ -521,9 +528,9 @@ public class CastBuilderTest { @Test public void testMessageArgumentAsLambda() { Function<Object, Object> argMsg = name -> "something"; - arg.conf(c -> c.allowNull().mustNotBeMissing(SHOW_CALLER, RError.Message.GENERIC, argMsg)).mustBe(stringValue(), RError.Message.GENERIC, argMsg); + arg.mustNotBeMissing(SHOW_CALLER, RError.Message.GENERIC, argMsg).mustBe(stringValue(), RError.Message.GENERIC, argMsg); - assertCastPreserves(RNull.instance); + assertCastPreserves("abc"); assertCastFail(RMissing.instance, "something"); assertCastFail(42, "something"); } @@ -594,7 +601,7 @@ public class CastBuilderTest { @Test public void testSample22() { - arg.conf(c -> c.mapMissing(emptyStringVector()).mapNull(emptyStringVector())).mustBe(stringValue()); + arg.mapIf(nullValue().or(missingValue()), emptyStringVector()).mustBe(stringValue()); arg.mapNull(emptyStringVector()).mustBe(stringValue()); Object res = cast(RNull.instance); assertTrue(res instanceof RAbstractStringVector); @@ -606,6 +613,21 @@ public class CastBuilderTest { assertEquals("abc", res); } + @Test + public void testSample23() { + //@formatter:off + arg.defaultError(RError.Message.INVALID_UNNAMED_ARGUMENTS). + mustBe(abstractVectorValue()). + asIntegerVector(). + findFirst(RRuntime.INT_NA). + mustBe(intNA().not().and(gte(0))); + //@formatter:on + assertEquals(1, cast(1)); + assertEquals(1, cast(1)); + assertEquals(1, cast("1")); + assertCastFail(RError.Message.INVALID_UNNAMED_ARGUMENTS.message, RRuntime.INT_NA, -1, RNull.instance); + } + @Test public void testSampleNonNASequence() { arg.notNA(RError.Message.GENERIC, "Error"); @@ -624,7 +646,7 @@ public class CastBuilderTest { @Test public void testPreserveNonVectorFlag() { - arg.allowNull().asVector(true); + arg.asVector(true); assertEquals(RNull.instance, cast(RNull.instance)); } @@ -781,6 +803,43 @@ public class CastBuilderTest { Assert.assertEquals("abc", cast("abc")); } + @Test + public void testComplexFilterWithForwardedNull() { + arg.mustBe(nullValue().or(numericValue()).or(stringValue()).or(complexValue())).mapIf(numericValue().or(complexValue()), asIntegerVector()); + Assert.assertEquals(RNull.instance, cast(RNull.instance)); + Assert.assertEquals("abc", cast("abc")); + } + + @Test + public void testFindFirstOrNull() { + arg.mustBe(nullValue().or(integerValue())).asIntegerVector().findFirstOrNull(); + Assert.assertEquals(RNull.instance, cast(RNull.instance)); + Assert.assertEquals(1, cast(1)); + } + + @Test + public void testReturnIf() { + arg.returnIf(nullValue(), constant(1.1)).mustBe(logicalValue()).asLogicalVector().findFirst().map(toBoolean()); + Assert.assertEquals(1.1, cast(RNull.instance)); + Assert.assertEquals(true, cast(RRuntime.LOGICAL_TRUE)); + } + + @Test + public void testNotNullAndNotMissing() { + arg.mustBe(nullValue().not().and(missingValue().not())); + try { + cast(RNull.instance); + fail(); + } catch (Exception e) { + } + try { + cast(RMissing.instance); + fail(); + } catch (Exception e) { + } + Assert.assertEquals("abc", cast("abc")); + } + /** * Casts given object using the configured pipeline in {@link #arg}. */ @@ -813,6 +872,12 @@ public class CastBuilderTest { assertEquals("Expected warning message", expectedMessage, CastNode.getLastWarning()); } + private void assertCastFail(String expectedMessage, Object... values) { + for (Object value : values) { + assertCastFail(value, expectedMessage); + } + } + private void assertCastFail(Object value) { assertCastFail(value, String.format(RError.Message.INVALID_ARGUMENT.message, "x")); } @@ -834,6 +899,11 @@ public class CastBuilderTest { testPipeline(false); } + private TypeExpr resultTypes() { + CastNodeSampler<CastNode> sampler = CastNodeSampler.createSampler(cb.getCasts()[0]); + return sampler.resultTypes(); + } + private void testPipeline(boolean emptyPositiveSamplesAllowed) { if (!TEST_SAMPLING) { return; diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/CastNodeSampler.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/CastNodeSampler.java index 0e64e2c2cd..880cfb0c9a 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/CastNodeSampler.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/CastNodeSampler.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2013, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -23,6 +23,8 @@ package com.oracle.truffle.r.nodes.casts; import java.lang.reflect.Constructor; +import java.util.LinkedList; +import java.util.List; import com.oracle.truffle.r.nodes.unary.CastNode; @@ -39,10 +41,15 @@ public class CastNodeSampler<T extends CastNode> { } public final TypeExpr resultTypes() { - return resultTypes(TypeExpr.ANYTHING); + SamplingContext ctx = new SamplingContext(); + TypeExpr resTypes = resultTypes(TypeExpr.ANYTHING, ctx); + for (TypeExpr altResType : ctx.altResultTypes) { + resTypes = resTypes.or(altResType); + } + return resTypes; } - public TypeExpr resultTypes(TypeExpr inputType) { + public TypeExpr resultTypes(TypeExpr inputType, SamplingContext ctx) { return CastUtils.Casts.createCastNodeCasts(castNode.getClass().getSuperclass()).narrow(inputType); } @@ -86,4 +93,12 @@ public class CastNodeSampler<T extends CastNode> { throw new IllegalArgumentException(e); } } + + public static final class SamplingContext { + private List<TypeExpr> altResultTypes = new LinkedList<>(); + + public void addAltResultType(TypeExpr altResType) { + altResultTypes.add(altResType); + } + } } diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/FilterSamplerFactory.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/FilterSamplerFactory.java index 1886cc2aaa..5341c5cef8 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/FilterSamplerFactory.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/FilterSamplerFactory.java @@ -80,26 +80,28 @@ public final class FilterSamplerFactory return x -> filter.test(x); } + @SuppressWarnings("rawtypes") @Override public ArgumentFilterSampler<?, ?> visit(TypeFilter<?, ?> filter) { + Class<?>[] filterTypes = new Class[]{filter.getType1(), filter.getType2()}; return TypePredicateArgumentFilterSampler.fromLambda(toPredicate(filter.getInstanceOfLambda()), - CastUtils.sampleValuesForClases(filter.getType()), CastUtils.samples(null), filter.getType()); + CastUtils.sampleValuesForClases(filterTypes), CastUtils.samples(null), filterTypes); } @Override public ArgumentFilterSampler<?, ?> visit(RTypeFilter<?> filter) { if (filter.getType() == RType.Integer) { - return visit(new TypeFilter<>(x -> x instanceof Integer || x instanceof RAbstractIntVector, Integer.class, RAbstractIntVector.class)); + return visit(new TypeFilter<>(Integer.class, RAbstractIntVector.class)); } else if (filter.getType() == RType.Double) { - return visit(new TypeFilter<>(x -> x instanceof Double || x instanceof RAbstractDoubleVector, Double.class, RAbstractDoubleVector.class)); + return visit(new TypeFilter<>(Double.class, RAbstractDoubleVector.class)); } else if (filter.getType() == RType.Logical) { - return visit(new TypeFilter<>(x -> x instanceof Byte || x instanceof RAbstractLogicalVector, Byte.class, RAbstractLogicalVector.class)); + return visit(new TypeFilter<>(Byte.class, RAbstractLogicalVector.class)); } else if (filter.getType() == RType.Complex) { - return visit(new TypeFilter<>(x -> x instanceof RAbstractComplexVector, RAbstractComplexVector.class)); + return visit(new TypeFilter<>(RAbstractComplexVector.class)); } else if (filter.getType() == RType.Character) { - return visit(new TypeFilter<>(x -> x instanceof String || x instanceof RAbstractStringVector, String.class, RAbstractStringVector.class)); + return visit(new TypeFilter<>(String.class, RAbstractStringVector.class)); } else if (filter.getType() == RType.Raw) { - return visit(new TypeFilter<>(x -> x instanceof RAbstractRawVector, RAbstractRawVector.class)); + return visit(new TypeFilter<>(RAbstractRawVector.class)); } else { throw RInternalError.unimplemented("TODO: more types here"); } diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/ForwardedValuesAnalyserTest.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/ForwardedValuesAnalyserTest.java new file mode 100644 index 0000000000..76fb3d0457 --- /dev/null +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/ForwardedValuesAnalyserTest.java @@ -0,0 +1,409 @@ +/* + * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; + +import com.oracle.truffle.r.nodes.builtin.casts.Filter.AndFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.CompareFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.CompareFilter.ScalarValue; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.DoubleFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.MissingFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.NotFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.NullFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.OrFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.RTypeFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.TypeFilter; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.CoercionStep; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.FilterStep; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.FindFirstStep; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.MapIfStep; +import com.oracle.truffle.r.nodes.builtin.casts.analysis.ForwardedValuesAnalyser; +import com.oracle.truffle.r.nodes.builtin.casts.analysis.ForwardingAnalysisResult; +import com.oracle.truffle.r.runtime.RType; + +public class ForwardedValuesAnalyserTest { + @Test + public void testCoercion() { + PipelineStep<?, ?> firstStep = new CoercionStep<>(RType.Logical, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.isLogicalForwarded()); + assertTrue(result.isNullForwarded()); + assertTrue(result.isMissingForwarded()); + assertFalse(result.isIntegerForwarded()); + assertFalse(result.isStringForwarded()); + } + + @Test + public void testFindFirst() { + PipelineStep<?, ?> firstStep = new CoercionStep<>(RType.Character, false).setNext(new FindFirstStep<>("hello", String.class, null)); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.isStringForwarded()); + assertFalse(result.isLogicalForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + assertFalse(result.isIntegerForwarded()); + } + + @Test + public void testRTypeFilter() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new RTypeFilter<>(RType.Integer), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.isIntegerForwarded()); + assertFalse(result.isStringForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testTypeFilter() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new TypeFilter<>(Integer.class), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.isIntegerForwarded()); + assertFalse(result.isStringForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testTypeFilterWithExtraCondition() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new TypeFilter<>(Integer.class, x -> x > 1), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.integerForwarded.isUnknown()); + assertFalse(result.isStringForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testOrFilter1() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new OrFilter<>(new RTypeFilter<>(RType.Integer), new RTypeFilter<>(RType.Double)), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.isIntegerForwarded()); + assertTrue(result.isDoubleForwarded()); + assertFalse(result.isStringForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testOrFilter2() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new OrFilter<>(new TypeFilter<>(Integer.class, x -> x > 1), new RTypeFilter<>(RType.Double)), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.integerForwarded.isUnknown()); + assertTrue(result.isDoubleForwarded()); + assertFalse(result.isStringForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testAndFilter1() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new AndFilter<>(new RTypeFilter<>(RType.Integer), new RTypeFilter<>(RType.Double)), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertFalse(result.isIntegerForwarded()); + assertFalse(result.isDoubleForwarded()); + assertFalse(result.isStringForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testNot() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new AndFilter<>(new NotFilter<>(new RTypeFilter<>(RType.Integer)), new NotFilter<>(new RTypeFilter<>(RType.Double))), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.isStringForwarded()); + assertTrue(result.isComplexForwarded()); + assertTrue(result.isNullForwarded()); + assertTrue(result.isMissingForwarded()); + assertFalse(result.isIntegerForwarded()); + assertFalse(result.isDoubleForwarded()); + } + + @Test + public void testAndFilter2() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new AndFilter<>(new TypeFilter<>(Integer.class, x -> x > 1), new TypeFilter<>(Integer.class, x -> x % 2 == 0)), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.integerForwarded.isUnknown()); + assertFalse(result.isDoubleForwarded()); + assertFalse(result.isStringForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testIntegerCoercionFollowedByFilterWithExtraCondition() { + PipelineStep<?, ?> firstStep = new CoercionStep<>(RType.Integer, false).setNext(new FindFirstStep<>(1, Integer.class, null)).setNext( + new FilterStep<>(new TypeFilter<>(Integer.class, x -> x > 1), null, false)); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.integerForwarded.isUnknown()); + assertFalse(result.isDoubleForwarded()); + assertFalse(result.isStringForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testCompareScalarValueFilter() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new CompareFilter<>(CompareFilter.EQ, new CompareFilter.ScalarValue(1, RType.Integer)), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.integerForwarded.isUnknown()); + assertFalse(result.isDoubleForwarded()); + assertFalse(result.isStringForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testCompareNAValueFilter() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new CompareFilter<>(CompareFilter.EQ, new CompareFilter.NATest(RType.Integer)), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.integerForwarded.isUnknown()); + assertFalse(result.isDoubleForwarded()); + assertFalse(result.isStringForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testCompareStringLengthFilter() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new CompareFilter<>(CompareFilter.EQ, new CompareFilter.StringLength(1)), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.stringForwarded.isUnknown()); + assertFalse(result.isDoubleForwarded()); + assertFalse(result.isIntegerForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testCompareEmptyVectorFilter() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new CompareFilter<>(CompareFilter.EQ, new CompareFilter.VectorSize(0)), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + assertFalse(result.isIntegerForwarded()); + assertFalse(result.isDoubleForwarded()); + assertFalse(result.isIntegerForwarded()); + } + + @Test + public void testCompareOneElementVectorFilter() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new CompareFilter<>(CompareFilter.EQ, new CompareFilter.VectorSize(1)), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.isIntegerForwarded()); + assertTrue(result.isDoubleForwarded()); + assertTrue(result.isIntegerForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testCompareMoreElementsVectorFilter() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new CompareFilter<>(CompareFilter.GE, new CompareFilter.VectorSize(2)), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertFalse(result.isIntegerForwarded()); + assertFalse(result.isDoubleForwarded()); + assertFalse(result.isIntegerForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testCompareMoreElementsVectorFilter2() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new CompareFilter<>(CompareFilter.LE, new CompareFilter.VectorSize(2)), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.isIntegerForwarded()); + assertTrue(result.isDoubleForwarded()); + assertTrue(result.isIntegerForwarded()); + assertTrue(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testCompareElementAt0Filter() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new CompareFilter<>(CompareFilter.EQ, new CompareFilter.ElementAt(0, 1, RType.Integer)), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.integerForwarded.isUnknown()); + assertFalse(result.isDoubleForwarded()); + assertFalse(result.isIntegerForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testCompareElementAt1Filter() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new CompareFilter<>(CompareFilter.EQ, new CompareFilter.ElementAt(1, 1, RType.Integer)), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertFalse(result.isIntegerForwarded()); + assertFalse(result.isDoubleForwarded()); + assertFalse(result.isIntegerForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testCompareDimFilter() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new CompareFilter<>(CompareFilter.EQ, new CompareFilter.Dim(0, 1)), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertFalse(result.isIntegerForwarded()); + assertFalse(result.isDoubleForwarded()); + assertFalse(result.isIntegerForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testDoubleFilter() { + PipelineStep<?, ?> firstStep = new FilterStep<>(DoubleFilter.IS_FINITE, null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.doubleForwarded.isUnknown()); + assertFalse(result.isIntegerForwarded()); + assertFalse(result.isStringForwarded()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testNotNullFilter() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new NotFilter<>(NullFilter.INSTANCE), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.isIntegerForwarded()); + assertTrue(result.isDoubleForwarded()); + assertTrue(result.isStringForwarded()); + assertTrue(result.isMissingForwarded()); + assertFalse(result.isNullForwarded()); + } + + @Test + public void testNotMissingFilter() { + PipelineStep<?, ?> firstStep = new FilterStep<>(new NotFilter<>(MissingFilter.INSTANCE), null, false); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.isIntegerForwarded()); + assertTrue(result.isDoubleForwarded()); + assertTrue(result.isStringForwarded()); + assertTrue(result.isNullForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testNoBranchReturnIf() { + PipelineStep<?, ?> firstStep = new MapIfStep<>(NullFilter.INSTANCE, null, null, true).setNext(new FilterStep<>(new RTypeFilter<>(RType.Integer), null, false)); + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.isIntegerForwarded()); + assertTrue(result.isNullForwarded()); + assertFalse(result.isDoubleForwarded()); + assertFalse(result.isStringForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testReturnIfWithTrueBranch() { + //@formatter:off + PipelineStep<?, ?> firstStep = new MapIfStep<>(new RTypeFilter<>(RType.Integer), + new FilterStep<>(new CompareFilter<>(CompareFilter.EQ, new ScalarValue(1, RType.Integer)), null, false), null, true). + setNext(new FilterStep<>(new RTypeFilter<>(RType.Double), null, false)); + //@formatter:on + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.isDoubleForwarded()); + assertTrue(result.integerForwarded.isUnknown()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isStringForwarded()); + assertFalse(result.isMissingForwarded()); + } + + @Test + public void testReturnIfWithBothBranches() { + //@formatter:off + PipelineStep<?, ?> firstStep = new MapIfStep<>(new RTypeFilter<>(RType.Integer), // the condition + // true branch + new FilterStep<>(new CompareFilter<>(CompareFilter.EQ, new ScalarValue(1, RType.Integer)), null, false), + // false branch + new FilterStep<>(new RTypeFilter<>(RType.Double), null, false), true); + //@formatter:on + + ForwardedValuesAnalyser fwdAn = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult result = fwdAn.analyse(firstStep); + assertTrue(result.isDoubleForwarded()); + assertTrue(result.integerForwarded.isUnknown()); + assertFalse(result.isNullForwarded()); + assertFalse(result.isStringForwarded()); + assertFalse(result.isMissingForwarded()); + } +} diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/PipelineToCastNodeTests.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/PipelineToCastNodeTests.java index e31dde44de..c74cd58b07 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/PipelineToCastNodeTests.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/PipelineToCastNodeTests.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2013, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -98,7 +98,7 @@ public class PipelineToCastNodeTests { @Test public void mustBeREnvironmentAsIntegerVectorFindFirst() { - CastNode pipeline = createPipeline(new FilterStep<>(new TypeFilter<>(x -> x instanceof REnvironment, REnvironment.class), null, false).setNext( + CastNode pipeline = createPipeline(new FilterStep<>(new TypeFilter<>(REnvironment.class), null, false).setNext( new CoercionStep<>(RType.Integer, false).setNext(new FindFirstStep<>("hello", String.class, null)))); CastNode chain = assertBypassNode(pipeline); assertChainedCast(chain, ChainedCastNode.class, FindFirstNode.class); @@ -126,6 +126,7 @@ public class PipelineToCastNodeTests { private static CastNode createPipeline(PipelineStep<?, ?> lastStep) { PipelineConfigBuilder configBuilder = new PipelineConfigBuilder("x"); + configBuilder.setValueForwarding(false); return PipelineToCastNode.convert(configBuilder.build(), lastStep); } } diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/BypassNodeGenSampler.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/BypassNodeGenSampler.java index 255e3ea814..8ec4b0dc84 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/BypassNodeGenSampler.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/BypassNodeGenSampler.java @@ -48,8 +48,8 @@ public class BypassNodeGenSampler extends CastNodeSampler<BypassNodeGen> { } @Override - public TypeExpr resultTypes(TypeExpr inputTypes) { - TypeExpr rt = wrappedHeadSampler == null ? TypeExpr.ANYTHING : wrappedHeadSampler.resultTypes(inputTypes); + public TypeExpr resultTypes(TypeExpr inputTypes, SamplingContext ctx) { + TypeExpr rt = wrappedHeadSampler == null ? TypeExpr.ANYTHING : wrappedHeadSampler.resultTypes(inputTypes, ctx); if (nullMapper != null) { rt = rt.or(nullMapper.resultTypes(inputTypes)); } else { diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/CastToVectorNodeGenSampler.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/CastToVectorNodeGenSampler.java index 14e7e7d935..c493163b64 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/CastToVectorNodeGenSampler.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/CastToVectorNodeGenSampler.java @@ -44,7 +44,7 @@ public class CastToVectorNodeGenSampler extends CastNodeSampler<CastToVectorNode } @Override - public TypeExpr resultTypes(TypeExpr inputType) { + public TypeExpr resultTypes(TypeExpr inputType, SamplingContext ctx) { List<Cast> castList; if (castNode.isPreserveNonVector()) { castList = Arrays.asList(new Cast(RNull.class, RNull.class), diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/ChainedCastNodeSampler.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/ChainedCastNodeSampler.java index 2ea68618a3..d2a01b173d 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/ChainedCastNodeSampler.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/ChainedCastNodeSampler.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2013, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -39,8 +39,8 @@ public final class ChainedCastNodeSampler extends CastNodeSampler<ChainedCastNod } @Override - public TypeExpr resultTypes(TypeExpr inputTypes) { - return secondCast.resultTypes(firstCast.resultTypes(inputTypes)); + public TypeExpr resultTypes(TypeExpr inputTypes, SamplingContext ctx) { + return secondCast.resultTypes(firstCast.resultTypes(inputTypes, ctx), ctx); } @Override @@ -50,7 +50,7 @@ public final class ChainedCastNodeSampler extends CastNodeSampler<ChainedCastNod @Override public Samples<?> collectSamples(TypeExpr inputTypes, Samples<?> downStreamSamples) { - TypeExpr rt1 = firstCast.resultTypes(inputTypes); + TypeExpr rt1 = firstCast.resultTypes(inputTypes, new SamplingContext()); return firstCast.collectSamples(inputTypes, secondCast.collectSamples(rt1, downStreamSamples)); } } diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNodeGenSampler.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNodeGenSampler.java index 8ccde29473..3b8a88b017 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNodeGenSampler.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNodeGenSampler.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2013, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -42,17 +42,26 @@ public class ConditionalMapNodeGenSampler extends CastNodeSampler<ConditionalMap } @Override - public TypeExpr resultTypes(TypeExpr inputType) { - return trueBranchResultTypes(inputType).or(falseBranchResultTypes(inputType)); + public TypeExpr resultTypes(TypeExpr inputType, SamplingContext ctx) { + if (castNode.isReturns()) { + ctx.addAltResultType(trueBranchResultTypes(inputType, ctx)); + return falseBranchResultTypes(inputType, ctx); + } else { + return trueBranchResultTypes(inputType, ctx).or(falseBranchResultTypes(inputType, ctx)); + } } - private TypeExpr trueBranchResultTypes(TypeExpr inputType) { - return trueBranch.resultTypes(argFilter.trueBranchType().and(inputType)); + private TypeExpr trueBranchResultTypes(TypeExpr inputType, SamplingContext ctx) { + if (trueBranch != null) { + return trueBranch.resultTypes(argFilter.trueBranchType().and(inputType), ctx); + } else { + return argFilter.trueBranchType().and(inputType); + } } - private TypeExpr falseBranchResultTypes(TypeExpr inputType) { + private TypeExpr falseBranchResultTypes(TypeExpr inputType, SamplingContext ctx) { if (falseBranch != null) { - return falseBranch.resultTypes(argFilter.falseBranchType().and(inputType)); + return falseBranch.resultTypes(argFilter.falseBranchType().and(inputType), ctx); } else { return argFilter.falseBranchType().and(inputType); } @@ -60,8 +69,8 @@ public class ConditionalMapNodeGenSampler extends CastNodeSampler<ConditionalMap @Override public Samples<?> collectSamples(TypeExpr inputType, Samples<?> downStreamSamples) { - TypeExpr trueBranchResultType = trueBranchResultTypes(inputType); - TypeExpr falseBranchResultType = falseBranchResultTypes(inputType); + TypeExpr trueBranchResultType = trueBranchResultTypes(inputType, new SamplingContext()); + TypeExpr falseBranchResultType = falseBranchResultTypes(inputType, new SamplingContext()); // filter out the incompatible samples Samples compatibleTrueBranchDownStreamSamples = downStreamSamples.filter(x -> trueBranchResultType.isInstance(x)); diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/FilterNodeGenSampler.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/FilterNodeGenSampler.java index c4f488fc00..4a73a29cce 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/FilterNodeGenSampler.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/FilterNodeGenSampler.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2013, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -43,7 +43,7 @@ public class FilterNodeGenSampler extends CastNodeSampler<FilterNodeGen> { } @Override - public TypeExpr resultTypes(TypeExpr inputType) { + public TypeExpr resultTypes(TypeExpr inputType, SamplingContext ctx) { if (isWarning) { return inputType; } else { diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/FindFirstNodeGenSampler.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/FindFirstNodeGenSampler.java index a7a41d91c8..2135903ab2 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/FindFirstNodeGenSampler.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/FindFirstNodeGenSampler.java @@ -113,7 +113,7 @@ public class FindFirstNodeGenSampler extends CastNodeSampler<FindFirstNodeGen> { } @Override - public TypeExpr resultTypes(TypeExpr inputType) { + public TypeExpr resultTypes(TypeExpr inputType, SamplingContext ctx) { TypeExpr rt; if (elementClass == null || elementClass == Object.class) { if (inputType.isAnything()) { diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/MapNodeSampler.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/MapNodeSampler.java index 3a62d8e3f1..079c4090a1 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/MapNodeSampler.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/MapNodeSampler.java @@ -39,7 +39,7 @@ public class MapNodeSampler extends CastNodeSampler<MapNode> { } @Override - public TypeExpr resultTypes(TypeExpr inputTypes) { + public TypeExpr resultTypes(TypeExpr inputTypes, SamplingContext ctx) { return mapFn.resultTypes(inputTypes); } diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/NonNANodeGenSampler.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/NonNANodeGenSampler.java index 41d275152e..3f2f805208 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/NonNANodeGenSampler.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/NonNANodeGenSampler.java @@ -44,7 +44,7 @@ public class NonNANodeGenSampler extends CastNodeSampler<NonNANodeGen> { } @Override - public TypeExpr resultTypes(TypeExpr inputType) { + public TypeExpr resultTypes(TypeExpr inputType, SamplingContext ctx) { return inputType; } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/CastBuilder.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/CastBuilder.java index 8bf41038f3..8ad506b96d 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/CastBuilder.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/CastBuilder.java @@ -216,7 +216,7 @@ public final class CastBuilder { * * Analogous methods exist for {@code RMissing}. */ - public PreinitialPhaseBuilder<Object> arg(String argumentName) { + public PreinitialPhaseBuilder arg(String argumentName) { assert builtin != null : "arg(String) is only supported for builtins cast pipelines"; return getBuilder(getArgumentIndex(argumentName), argumentName).fluent(); } @@ -224,7 +224,7 @@ public final class CastBuilder { /** * @see #arg(String) */ - public PreinitialPhaseBuilder<Object> arg(int argumentIndex, String argumentName) { + public PreinitialPhaseBuilder arg(int argumentIndex, String argumentName) { assert argumentNames == null || argumentIndex >= 0 && argumentIndex < argumentBuilders.length : "argument index out of range"; assert argumentNames == null || argumentNames[argumentIndex].equals(argumentName) : "wrong argument name " + argumentName; return getBuilder(argumentIndex, argumentName).fluent(); @@ -233,7 +233,7 @@ public final class CastBuilder { /** * @see #arg(String) */ - public PreinitialPhaseBuilder<Object> arg(int argumentIndex) { + public PreinitialPhaseBuilder arg(int argumentIndex) { boolean existingIndex = argumentNames != null && argumentIndex >= 0 && argumentIndex < argumentNames.length; String name = existingIndex ? argumentNames[argumentIndex] : null; return getBuilder(argumentIndex, name).fluent(); @@ -301,13 +301,21 @@ public final class CastBuilder { } public static <T, S extends T, R> PipelineStep<T, R> mapIf(Filter<? super T, S> filter, PipelineStep<?, ?> trueBranch, PipelineStep<?, ?> falseBranch) { - return new MapIfStep<>(filter, trueBranch, falseBranch); + return new MapIfStep<>(filter, trueBranch, falseBranch, false); + } + + public static <T, S extends T, R> PipelineStep<T, R> returnIf(Filter<? super T, S> filter, PipelineStep<?, ?> trueBranch, PipelineStep<?, ?> falseBranch) { + return new MapIfStep<>(filter, trueBranch, falseBranch, true); } public static <T, S extends T, R> PipelineStep<T, R> mapIf(Filter<? super T, S> filter, PipelineStep<?, ?> trueBranch) { return mapIf(filter, trueBranch, null); } + public static <T, S extends T, R> PipelineStep<T, R> returnIf(Filter<? super T, S> filter, PipelineStep<?, ?> trueBranch) { + return returnIf(filter, trueBranch, null); + } + public static <T> ChainBuilder<T> chain(PipelineStep<T, ?> firstStep) { return new ChainBuilder<>(firstStep); } @@ -645,11 +653,11 @@ public final class CastBuilder { } public static <R> TypeFilter<Object, R> instanceOf(Class<R> cls) { - return new TypeFilter<>(x -> cls.isInstance(x), cls); + return new TypeFilter<>(cls); } public static TypeFilter<Object, RFunction> builtin() { - return new TypeFilter<>(x -> RFunction.class.isInstance(x) && ((RFunction) x).isBuiltin(), RFunction.class); + return new TypeFilter<>(RFunction.class, x -> x.isBuiltin()); } public static <R extends RAbstractIntVector> Filter<Object, R> integerValue() { @@ -676,8 +684,8 @@ public final class CastBuilder { return new RTypeFilter<>(RType.Raw); } - public static <R> TypeFilter<Object, R> anyValue() { - return new TypeFilter<>(x -> true, Object.class); + public static TypeFilter<Object, Object> anyValue() { + return new TypeFilter<>(Object.class); } @SuppressWarnings({"rawtypes", "unchecked"}) @@ -697,11 +705,11 @@ public final class CastBuilder { } public static Filter<Object, Integer> atomicIntegerValue() { - return new TypeFilter<>(x -> x instanceof String, String.class); + return new TypeFilter<>(Integer.class); } public static Filter<Object, Byte> atomicLogicalValue() { - return new TypeFilter<>(x -> x instanceof Byte, Byte.class); + return new TypeFilter<>(Byte.class); } public static MapByteToBoolean toBoolean() { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/Filter.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/Filter.java index c74b22baa8..f26904e04c 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/Filter.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/Filter.java @@ -22,6 +22,8 @@ */ package com.oracle.truffle.r.nodes.builtin.casts; +import java.util.function.Predicate; + import com.oracle.truffle.r.nodes.builtin.ArgumentFilter; import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.FilterStep; import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.MapStep; @@ -91,16 +93,45 @@ public abstract class Filter<T, R extends T> { * Filters specific Java class. */ public static final class TypeFilter<T, R extends T> extends Filter<T, R> { - private final Class<?>[] type; - private final ArgumentFilter<Object, Object> instanceOfLambda; + private final Class<?> type1; + private final Class<?> type2; + private final Predicate<R> extraCondition; - public TypeFilter(ArgumentFilter<Object, Object> instanceOfLambda, Class<?>... type) { - this.type = type; - this.instanceOfLambda = instanceOfLambda; + @SuppressWarnings("rawtypes") + public TypeFilter(Class<R> type) { + assert type != null; + this.type1 = type; + this.type2 = null; + this.extraCondition = null; } - public Class<?>[] getType() { - return type; + @SuppressWarnings({"unchecked", "rawtypes"}) + public TypeFilter(Class<R> type, Predicate<R> extraCondition) { + assert type != null; + this.type1 = type; + this.type2 = null; + this.extraCondition = extraCondition; + } + + @SuppressWarnings("rawtypes") + public TypeFilter(Class<?> type1, Class<?> type2) { + assert type1 != null && type2 != null; + assert type1 != Object.class && type2 != Object.class; + this.type1 = type1; + this.type2 = type2; + this.extraCondition = null; + } + + public Class<?> getType1() { + return type1; + } + + public Class<?> getType2() { + return type2; + } + + public Predicate<R> getExtraCondition() { + return extraCondition; } @Override @@ -108,7 +139,26 @@ public abstract class Filter<T, R extends T> { return true; } + @SuppressWarnings("unchecked") public ArgumentFilter<Object, Object> getInstanceOfLambda() { + final ArgumentFilter<Object, Object> instanceOfLambda; + if (type2 == null) { + if (extraCondition == null) { + if (type1 == Object.class) { + instanceOfLambda = x -> true; + } else { + instanceOfLambda = x -> type1.isInstance(x); + } + } else { + if (type1 == Object.class) { + instanceOfLambda = x -> extraCondition.test((R) x); + } else { + instanceOfLambda = x -> type1.isInstance(x) && extraCondition.test((R) x); + } + } + } else { + instanceOfLambda = x -> type1.isInstance(x) || type2.isInstance(x); + } return instanceOfLambda; } @@ -126,6 +176,7 @@ public abstract class Filter<T, R extends T> { public ResultForArg resultForMissing() { return ResultForArg.FALSE; } + } /** @@ -363,6 +414,16 @@ public abstract class Filter<T, R extends T> { public <D> D accept(FilterVisitor<D> visitor) { return visitor.visit(this); } + + @Override + public ResultForArg resultForNull() { + if (subject instanceof VectorSize && ((VectorSize) subject).size == 0) { + return ResultForArg.TRUE; + } else { + return ResultForArg.FALSE; + } + } + } public abstract static class MatrixFilter<T extends RAbstractVector> extends Filter<T, T> { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineConfig.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineConfig.java index 3ecb60adfc..42146741e4 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineConfig.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineConfig.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2016, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -48,14 +48,17 @@ public class PipelineConfig { private final Mapper<? super RNull, ?> nullMapper; private final MessageData missingMsg; private final MessageData nullMsg; + private boolean valueForwarding; public PipelineConfig(String argumentName, MessageData defaultError, MessageData defaultWarning, Mapper<? super RMissing, ?> missingMapper, Mapper<? super RNull, ?> nullMapper, + boolean valueForwarding, MessageData missingMsg, MessageData nullMsg) { this.defaultError = defaultError; this.defaultWarning = defaultWarning; this.missingMapper = missingMapper; this.nullMapper = nullMapper; + this.valueForwarding = valueForwarding; this.missingMsg = missingMsg; this.nullMsg = nullMsg; this.argumentName = argumentName; @@ -93,6 +96,10 @@ public class PipelineConfig { return nullMsg; } + public boolean getValueForwarding() { + return valueForwarding; + } + public static ArgumentFilterFactory getFilterFactory() { return filterFactory; } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineStep.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineStep.java index a253a907b2..27704a9d4b 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineStep.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineStep.java @@ -49,6 +49,16 @@ public abstract class PipelineStep<T, R> { public abstract <D> D accept(PipelineStepVisitor<D> visitor); + public <D> D acceptPipeline(PipelineStepVisitor<D> visitor) { + PipelineStep<?, ?> curStep = this; + D result = null; + while (curStep != null) { + result = curStep.accept(visitor); + curStep = curStep.getNext(); + } + return result; + } + public interface PipelineStepVisitor<T> { T visit(FindFirstStep<?, ?> step); @@ -277,11 +287,13 @@ public abstract class PipelineStep<T, R> { private final Filter<?, ?> filter; private final PipelineStep<?, ?> trueBranch; private final PipelineStep<?, ?> falseBranch; + private final boolean returns; - public MapIfStep(Filter<?, ?> filter, PipelineStep<?, ?> trueBranch, PipelineStep<?, ?> falseBranch) { + public MapIfStep(Filter<?, ?> filter, PipelineStep<?, ?> trueBranch, PipelineStep<?, ?> falseBranch, boolean returns) { this.filter = filter; this.trueBranch = trueBranch; this.falseBranch = falseBranch; + this.returns = returns; } public Filter<?, ?> getFilter() { @@ -296,6 +308,10 @@ public abstract class PipelineStep<T, R> { return falseBranch; } + public boolean isReturns() { + return returns; + } + @Override public <D> D accept(PipelineStepVisitor<D> visitor) { return visitor.visit(this); diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineToCastNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineToCastNode.java index fdb073e077..98cb9ac462 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineToCastNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/PipelineToCastNode.java @@ -22,6 +22,8 @@ */ package com.oracle.truffle.r.nodes.builtin.casts; +import java.util.function.Supplier; + import com.oracle.truffle.r.nodes.binary.BoxPrimitiveNode; import com.oracle.truffle.r.nodes.builtin.ArgumentFilter; import com.oracle.truffle.r.nodes.builtin.ArgumentFilter.ArgumentTypeFilter; @@ -61,6 +63,8 @@ import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.MapIfStep; import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.MapStep; import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.NotNAStep; import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.PipelineStepVisitor; +import com.oracle.truffle.r.nodes.builtin.casts.analysis.ForwardedValuesAnalyser; +import com.oracle.truffle.r.nodes.builtin.casts.analysis.ForwardingAnalysisResult; import com.oracle.truffle.r.nodes.unary.BypassNode; import com.oracle.truffle.r.nodes.unary.BypassNodeGen.BypassDoubleNodeGen; import com.oracle.truffle.r.nodes.unary.BypassNodeGen.BypassIntegerNodeGen; @@ -116,11 +120,26 @@ public final class PipelineToCastNode { public static CastNode convert(PipelineConfig config, PipelineStep<?, ?> firstStep, ArgumentFilterFactory filterFactory, ArgumentMapperFactory mapperFactory) { if (firstStep == null) { return BypassNode.create(config, null, mapperFactory, null); - } else { + } + + Supplier<CastNode> originalPipelineFactory = () -> { CastNodeFactory nodeFactory = new CastNodeFactory(filterFactory, mapperFactory, config.getDefaultDefaultMessage()); SinglePrimitiveOptimization singleOptVisitor = new SinglePrimitiveOptimization(nodeFactory); CastNode headNode = convert(firstStep, singleOptVisitor); return singleOptVisitor.createBypassNode(config, headNode, mapperFactory); + }; + + if (!config.getValueForwarding()) { + return originalPipelineFactory.get(); + } + + // TODO: the result of this analysis should be cached + ForwardedValuesAnalyser fwdAnalyser = new ForwardedValuesAnalyser(); + ForwardingAnalysisResult fwdAnalytics = fwdAnalyser.analyse(firstStep); + if (fwdAnalytics.isAnythingForwarded()) { + return ValueForwardingNodeGen.create(fwdAnalytics, originalPipelineFactory); + } else { + return originalPipelineFactory.get(); } } @@ -440,7 +459,7 @@ public final class PipelineToCastNode { CastNode trueCastNode = PipelineToCastNode.convert(step.getTrueBranch(), this); CastNode falseCastNode = PipelineToCastNode.convert(step.getFalseBranch(), this); return ConditionalMapNode.create(condition, trueCastNode, falseCastNode, ResultForArg.TRUE.equals(step.getFilter().resultForNull()), - ResultForArg.TRUE.equals(step.getFilter().resultForMissing())); + ResultForArg.TRUE.equals(step.getFilter().resultForMissing()), step.isReturns()); } private MessageData getDefaultErrorIfNull(MessageData message) { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/ValueForwardingNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/ValueForwardingNode.java new file mode 100644 index 0000000000..623afa39a4 --- /dev/null +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/ValueForwardingNode.java @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.builtin.casts; + +import java.util.function.Supplier; + +import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.r.nodes.builtin.casts.analysis.ForwardingAnalysisResult; +import com.oracle.truffle.r.nodes.unary.CastNode; +import com.oracle.truffle.r.runtime.RRuntime; +import com.oracle.truffle.r.runtime.data.RComplex; +import com.oracle.truffle.r.runtime.data.RMissing; +import com.oracle.truffle.r.runtime.data.RNull; + +public abstract class ValueForwardingNode extends CastNode { + + @SuppressWarnings("unused") protected final ForwardingAnalysisResult forwardingResult; + private final Supplier<CastNode> pipelineFactory; + + protected ValueForwardingNode(ForwardingAnalysisResult forwardingResult, Supplier<CastNode> pipelineFactory) { + this.forwardingResult = forwardingResult; + this.pipelineFactory = pipelineFactory; + } + + @Specialization(guards = "forwardingResult.isNullForwarded()") + protected Object bypassNull(RNull x) { + return x; + } + + @Specialization(guards = "forwardingResult.isMissingForwarded()") + protected Object bypassMissing(RMissing x) { + return x; + } + + @Specialization(guards = "forwardingResult.isIntegerForwarded()") + protected int bypassInteger(int x) { + return x; + } + + @Specialization(guards = "forwardingResult.isLogicalForwarded()") + protected byte bypassLogical(byte x) { + return x; + } + + @Specialization(guards = "forwardingResult.isLogicalMappedToBoolean()") + protected boolean mapLogicalToBoolean(byte x) { + return RRuntime.fromLogical(x); + } + + @Specialization(guards = "forwardingResult.isDoubleForwarded()") + protected double bypassDouble(double x) { + return x; + } + + @Specialization(guards = "forwardingResult.isComplexForwarded()") + protected RComplex bypassComplex(RComplex x) { + return x; + } + + @Specialization(guards = "forwardingResult.isStringForwarded()") + protected String bypassString(String x) { + return x; + } + + protected CastNode createPipeline() { + return pipelineFactory.get(); + } + + @Specialization + protected Object executeOriginalPipeline(Object x, @Cached("createPipeline()") CastNode pipelineHeadNode) { + return pipelineHeadNode.execute(x); + } +} diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/analysis/ForwardedValuesAnalyser.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/analysis/ForwardedValuesAnalyser.java new file mode 100644 index 0000000000..ce9b228102 --- /dev/null +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/analysis/ForwardedValuesAnalyser.java @@ -0,0 +1,317 @@ +/* + * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.builtin.casts.analysis; + +import static com.oracle.truffle.r.nodes.builtin.casts.analysis.ForwardingStatus.BLOCKED; +import static com.oracle.truffle.r.nodes.builtin.casts.analysis.ForwardingStatus.FORWARDED; +import static com.oracle.truffle.r.nodes.builtin.casts.analysis.ForwardingStatus.UNKNOWN; + +import com.oracle.truffle.r.nodes.builtin.casts.Filter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.AndFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.CompareFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.CompareFilter.Dim; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.CompareFilter.ElementAt; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.CompareFilter.NATest; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.CompareFilter.ScalarValue; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.CompareFilter.StringLength; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.CompareFilter.VectorSize; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.DoubleFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.MatrixFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.MissingFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.NotFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.NullFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.OrFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.RTypeFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.TypeFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Mapper; +import com.oracle.truffle.r.nodes.builtin.casts.Mapper.MapByteToBoolean; +import com.oracle.truffle.r.nodes.builtin.casts.Mapper.MapDoubleToInt; +import com.oracle.truffle.r.nodes.builtin.casts.Mapper.MapToCharAt; +import com.oracle.truffle.r.nodes.builtin.casts.Mapper.MapToValue; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.AttributableCoercionStep; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.BoxPrimitiveStep; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.CoercionStep; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.DefaultErrorStep; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.DefaultWarningStep; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.FilterStep; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.FindFirstStep; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.MapIfStep; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.MapStep; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.NotNAStep; +import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.PipelineStepVisitor; +import com.oracle.truffle.r.nodes.builtin.casts.analysis.ForwardingStatus.Forwarded; +import com.oracle.truffle.r.runtime.RType; +import com.oracle.truffle.r.runtime.data.RNull; + +public final class ForwardedValuesAnalyser implements PipelineStepVisitor<ForwardingAnalysisResult> { + + private ForwardingAnalysisResult result = new ForwardingAnalysisResult().forwardAll(); + private ForwardingAnalysisResult altResult = null; + + public ForwardingAnalysisResult analyse(PipelineStep<?, ?> firstStep) { + firstStep.acceptPipeline(this); + ForwardingAnalysisResult mainRes = result; + if (altResult != null) { + mainRes = mainRes.or(altResult); + } + return mainRes; + } + + private void addAlternativeResult(ForwardingAnalysisResult res) { + if (altResult == null) { + altResult = res; + } else { + altResult = altResult.or(res); + } + } + + @Override + public ForwardingAnalysisResult visit(FindFirstStep<?, ?> step) { + ForwardingAnalysisResult localRes = new ForwardingAnalysisResult().blockAll().setForwardedType(step.getElementClass(), FORWARDED); + if (step.getDefaultValue() == RNull.instance) { + // see CoercedPhaseBuilder.findFirstOrNull() + localRes.setNull(FORWARDED); + } + result = result.and(localRes); + return result; + } + + @Override + public ForwardingAnalysisResult visit(CoercionStep<?, ?> step) { + ForwardingAnalysisResult localRes = new ForwardingAnalysisResult().blockAll().setForwardedType(step.getType(), FORWARDED); + if (step.preserveNonVector) { + // i.e. preserve NULL and MISSING + localRes = localRes.setMissing(FORWARDED).setNull(FORWARDED); + } + result = result.and(localRes); + return result; + } + + @Override + public ForwardingAnalysisResult visit(MapStep<?, ?> step) { + ForwardingAnalysisResult localRes = step.getMapper().accept(new Mapper.MapperVisitor<ForwardingAnalysisResult>() { + + @Override + public ForwardingAnalysisResult visit(MapToValue<?, ?> mapper) { + return ForwardingAnalysisResult.INVALID; + } + + @Override + public ForwardingAnalysisResult visit(MapByteToBoolean mapper) { + return new ForwardingAnalysisResult().blockAll().setForwardedType(RType.Logical, new Forwarded(mapper)); + } + + @Override + public ForwardingAnalysisResult visit(MapDoubleToInt mapper) { + return ForwardingAnalysisResult.INVALID; + } + + @Override + public ForwardingAnalysisResult visit(MapToCharAt mapper) { + return ForwardingAnalysisResult.INVALID; + } + }); + result = result.and(localRes); + return result; + } + + @Override + public ForwardingAnalysisResult visit(MapIfStep<?, ?> mapIfStep) { + // analyze the true branch + ForwardingAnalysisResult trueBranchFwdRes; + PipelineStep<?, ?> trueBranchFirstStep = new FilterStep<>(mapIfStep.getFilter(), null, false); + ForwardedValuesAnalyser trueBranchFwdAnalyser = new ForwardedValuesAnalyser(); + if (mapIfStep.getTrueBranch() != null) { + trueBranchFirstStep.setNext(mapIfStep.getTrueBranch()); + } + trueBranchFwdRes = trueBranchFwdAnalyser.analyse(trueBranchFirstStep); + + // analyze the false branch + ForwardingAnalysisResult falseBranchFwdRes; + ForwardedValuesAnalyser falseBranchFwdAnalyser = new ForwardedValuesAnalyser(); + PipelineStep<?, ?> falseBranchFirstStep = new FilterStep<>(new NotFilter<>(mapIfStep.getFilter()), null, false); + if (mapIfStep.getFalseBranch() != null) { + falseBranchFirstStep.setNext(mapIfStep.getFalseBranch()); + } + falseBranchFwdRes = falseBranchFwdAnalyser.analyse(falseBranchFirstStep); + + if (mapIfStep.isReturns()) { + addAlternativeResult(result.and(trueBranchFwdRes)); + result = result.and(falseBranchFwdRes); + } else { + result = result.and(trueBranchFwdRes.or(falseBranchFwdRes)); + } + + return result; + } + + @Override + public ForwardingAnalysisResult visit(FilterStep<?, ?> step) { + class ForwardedValuesFilterVisitor implements Filter.FilterVisitor<ForwardingAnalysisResult> { + + @Override + public ForwardingAnalysisResult visit(TypeFilter<?, ?> filter) { + ForwardingAnalysisResult res = new ForwardingAnalysisResult().blockAll(); + if (filter.getExtraCondition() == null) { + res = res.setForwardedType(filter.getType1(), FORWARDED); + } else { + res = res.setForwardedType(filter.getType1(), UNKNOWN); + } + return res; + } + + @Override + public ForwardingAnalysisResult visit(RTypeFilter<?> filter) { + return new ForwardingAnalysisResult().blockAll().setForwardedType(filter.getType(), FORWARDED); + } + + @Override + public ForwardingAnalysisResult visit(CompareFilter<?> filter) { + return filter.getSubject().accept(new Filter.CompareFilter.SubjectVisitor<ForwardingAnalysisResult>() { + + @Override + public ForwardingAnalysisResult visit(ScalarValue scalarValue, byte operation) { + return new ForwardingAnalysisResult().blockAll().setForwardedType(scalarValue.type, UNKNOWN); + } + + @Override + public ForwardingAnalysisResult visit(NATest naTest, byte operation) { + return new ForwardingAnalysisResult().blockAll().setForwardedType(naTest.type, UNKNOWN); + } + + @Override + public ForwardingAnalysisResult visit(StringLength stringLength, byte operation) { + return new ForwardingAnalysisResult().blockAll().setForwardedType(RType.Character, UNKNOWN); + } + + @Override + public ForwardingAnalysisResult visit(VectorSize vectorSize, byte operation) { + if (vectorSize.size == 0 && operation == CompareFilter.EQ) { + return new ForwardingAnalysisResult().blockAll().setNull(FORWARDED); + } else if (vectorSize.size == 1 && (operation == CompareFilter.EQ || operation == CompareFilter.GT || operation == CompareFilter.GE)) { + return new ForwardingAnalysisResult().forwardAll().setNull(BLOCKED).setMissing(BLOCKED); + } else if (vectorSize.size > 1 && (operation == CompareFilter.EQ || operation == CompareFilter.GT || operation == CompareFilter.GE)) { + return new ForwardingAnalysisResult().blockAll(); + } else if (vectorSize.size > 1 && (operation == CompareFilter.LE || operation == CompareFilter.LT)) { + return new ForwardingAnalysisResult().forwardAll().setMissing(BLOCKED); + } else { + return new ForwardingAnalysisResult().unknownAll().setMissing(BLOCKED); + } + } + + @Override + public ForwardingAnalysisResult visit(ElementAt elementAt, byte operation) { + if (elementAt.index == 0) { + return new ForwardingAnalysisResult().blockAll().setForwardedType(elementAt.type, UNKNOWN); + } else { + return new ForwardingAnalysisResult().blockAll(); + } + } + + @Override + public ForwardingAnalysisResult visit(Dim dim, byte operation) { + return new ForwardingAnalysisResult().blockAll(); + } + }, filter.getOperation()); + } + + @Override + public ForwardingAnalysisResult visit(AndFilter<?, ?> filter) { + ForwardedValuesFilterVisitor leftVis = new ForwardedValuesFilterVisitor(); + ForwardedValuesFilterVisitor rightVis = new ForwardedValuesFilterVisitor(); + ForwardingAnalysisResult leftResult = filter.getLeft().accept(leftVis); + ForwardingAnalysisResult rightResult = filter.getRight().accept(rightVis); + return leftResult.and(rightResult); + } + + @Override + public ForwardingAnalysisResult visit(OrFilter<?> filter) { + ForwardedValuesFilterVisitor leftVis = new ForwardedValuesFilterVisitor(); + ForwardedValuesFilterVisitor rightVis = new ForwardedValuesFilterVisitor(); + ForwardingAnalysisResult leftResult = filter.getLeft().accept(leftVis); + ForwardingAnalysisResult rightResult = filter.getRight().accept(rightVis); + return leftResult.or(rightResult); + } + + @Override + public ForwardingAnalysisResult visit(NotFilter<?> filter) { + ForwardedValuesFilterVisitor vis = new ForwardedValuesFilterVisitor(); + return filter.getFilter().accept(vis).not(); + } + + @Override + public ForwardingAnalysisResult visit(MatrixFilter<?> filter) { + return new ForwardingAnalysisResult().forwardAll(); + } + + @Override + public ForwardingAnalysisResult visit(DoubleFilter filter) { + return new ForwardingAnalysisResult().blockAll().setForwardedType(RType.Double, UNKNOWN); + } + + @Override + public ForwardingAnalysisResult visit(NullFilter filter) { + return new ForwardingAnalysisResult().blockAll().setNull(FORWARDED); + } + + @Override + public ForwardingAnalysisResult visit(MissingFilter filter) { + return new ForwardingAnalysisResult().blockAll().setMissing(FORWARDED); + } + } + ForwardingAnalysisResult localRes = step.getFilter().accept(new ForwardedValuesFilterVisitor()); + result = result.and(localRes); + return result; + } + + @Override + public ForwardingAnalysisResult visit(NotNAStep<?> step) { + result = result.and(new ForwardingAnalysisResult().unknownAll().setNull(FORWARDED).setMissing(FORWARDED)); + return result; + } + + @Override + public ForwardingAnalysisResult visit(DefaultErrorStep<?> step) { + return result; + } + + @Override + public ForwardingAnalysisResult visit(DefaultWarningStep<?> step) { + return result; + } + + @Override + public ForwardingAnalysisResult visit(BoxPrimitiveStep<?> step) { + // TODO + result = ForwardingAnalysisResult.INVALID; + return result; + } + + @Override + public ForwardingAnalysisResult visit(AttributableCoercionStep<?> step) { + // TODO + result = ForwardingAnalysisResult.INVALID; + return result; + } +} diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/analysis/ForwardingAnalysisResult.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/analysis/ForwardingAnalysisResult.java new file mode 100644 index 0000000000..34c20fe18a --- /dev/null +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/analysis/ForwardingAnalysisResult.java @@ -0,0 +1,336 @@ +/* + * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.builtin.casts.analysis; + +import static com.oracle.truffle.r.nodes.builtin.casts.analysis.ForwardingStatus.BLOCKED; +import static com.oracle.truffle.r.nodes.builtin.casts.analysis.ForwardingStatus.FORWARDED; +import static com.oracle.truffle.r.nodes.builtin.casts.analysis.ForwardingStatus.UNKNOWN; + +import com.oracle.truffle.r.nodes.builtin.casts.Mapper.MapByteToBoolean; +import com.oracle.truffle.r.runtime.RType; +import com.oracle.truffle.r.runtime.data.RComplex; + +public final class ForwardingAnalysisResult { + + public final ForwardingStatus integerForwarded; + public final ForwardingStatus logicalForwarded; + public final ForwardingStatus doubleForwarded; + public final ForwardingStatus complexForwarded; + public final ForwardingStatus stringForwarded; + public final ForwardingStatus nullForwarded; + public final ForwardingStatus missingForwarded; + public final boolean invalid; + + static final ForwardingAnalysisResult INVALID = new ForwardingAnalysisResult(UNKNOWN, UNKNOWN, UNKNOWN, UNKNOWN, UNKNOWN, UNKNOWN, UNKNOWN, true); + + ForwardingAnalysisResult() { + this(UNKNOWN, UNKNOWN, UNKNOWN, UNKNOWN, UNKNOWN, UNKNOWN, UNKNOWN, false); + } + + private ForwardingAnalysisResult(ForwardingStatus integerForwarded, + ForwardingStatus logicalForwarded, + ForwardingStatus doubleForwarded, + ForwardingStatus complexForwarded, + ForwardingStatus stringForwarded, + ForwardingStatus nullForwarded, + ForwardingStatus missingForwarded, + boolean invalid) { + this.integerForwarded = integerForwarded; + this.logicalForwarded = logicalForwarded; + this.doubleForwarded = doubleForwarded; + this.complexForwarded = complexForwarded; + this.stringForwarded = stringForwarded; + this.nullForwarded = nullForwarded; + this.missingForwarded = missingForwarded; + this.invalid = invalid; + } + + public boolean isNullForwarded() { + return !invalid && nullForwarded.isForwarded(); + } + + public boolean isMissingForwarded() { + return !invalid && missingForwarded.isForwarded(); + } + + public boolean isIntegerForwarded() { + return !invalid && integerForwarded.isForwarded(); + } + + public boolean isLogicalForwarded() { + return !invalid && logicalForwarded.isForwarded(); + } + + public boolean isLogicalMappedToBoolean() { + return !invalid && logicalForwarded.mapper == MapByteToBoolean.INSTANCE; + } + + public boolean isDoubleForwarded() { + return !invalid && doubleForwarded.isForwarded(); + } + + public boolean isComplexForwarded() { + return !invalid && complexForwarded.isForwarded(); + } + + public boolean isStringForwarded() { + return !invalid && stringForwarded.isForwarded(); + } + + public boolean isAnythingForwarded() { + return isNullForwarded() || isMissingForwarded() || isIntegerForwarded() || isLogicalForwarded() || isDoubleForwarded() || isComplexForwarded() || isStringForwarded(); + } + + ForwardingAnalysisResult setForwardedType(Class<?> tp, ForwardingStatus status) { + if (invalid) { + return this; + } + + if (Integer.class == tp || int.class == tp) { + return new ForwardingAnalysisResult(status, + logicalForwarded, + doubleForwarded, + complexForwarded, + stringForwarded, + nullForwarded, + missingForwarded, + invalid); + } else if (Byte.class == tp || byte.class == tp) { + return new ForwardingAnalysisResult(integerForwarded, + status, + doubleForwarded, + complexForwarded, + stringForwarded, + nullForwarded, + missingForwarded, + invalid); + } else if (Double.class == tp || double.class == tp) { + return new ForwardingAnalysisResult(integerForwarded, + logicalForwarded, + status, + complexForwarded, + stringForwarded, + nullForwarded, + missingForwarded, + invalid); + } else if (RComplex.class == tp) { + return new ForwardingAnalysisResult(integerForwarded, + logicalForwarded, + doubleForwarded, + status, + stringForwarded, + nullForwarded, + missingForwarded, + invalid); + } else if (String.class == tp) { + return new ForwardingAnalysisResult(integerForwarded, + logicalForwarded, + doubleForwarded, + complexForwarded, + status, + nullForwarded, + missingForwarded, + invalid); + } else { + return this; + } + } + + ForwardingAnalysisResult setForwardedType(RType tp, ForwardingStatus status) { + if (invalid) { + return this; + } + + switch (tp) { + case Integer: + return new ForwardingAnalysisResult(status, + logicalForwarded, + doubleForwarded, + complexForwarded, + stringForwarded, + nullForwarded, + missingForwarded, + invalid); + case Logical: + return new ForwardingAnalysisResult(integerForwarded, + status, + doubleForwarded, + complexForwarded, + stringForwarded, + nullForwarded, + missingForwarded, + invalid); + case Double: + return new ForwardingAnalysisResult(integerForwarded, + logicalForwarded, + status, + complexForwarded, + stringForwarded, + nullForwarded, + missingForwarded, + invalid); + case Complex: + return new ForwardingAnalysisResult(integerForwarded, + logicalForwarded, + doubleForwarded, + status, + stringForwarded, + nullForwarded, + missingForwarded, + invalid); + case Character: + return new ForwardingAnalysisResult(integerForwarded, + logicalForwarded, + doubleForwarded, + complexForwarded, + status, + nullForwarded, + missingForwarded, + invalid); + default: + return this; + } + } + + ForwardingAnalysisResult and(ForwardingAnalysisResult other) { + if (this.invalid || other.invalid) { + return ForwardingAnalysisResult.INVALID; + } else { + return new ForwardingAnalysisResult(this.integerForwarded.and(other.integerForwarded), + this.logicalForwarded.and(other.logicalForwarded), + this.doubleForwarded.and(other.doubleForwarded), + this.complexForwarded.and(other.complexForwarded), + this.stringForwarded.and(other.stringForwarded), + this.nullForwarded.and(other.nullForwarded), + this.missingForwarded.and(other.missingForwarded), + false); + } + } + + ForwardingAnalysisResult or(ForwardingAnalysisResult other) { + if (this.invalid) { + return other; + } else if (other.invalid) { + return this; + } else { + return new ForwardingAnalysisResult(this.integerForwarded.or(other.integerForwarded), + this.logicalForwarded.or(other.logicalForwarded), + this.doubleForwarded.or(other.doubleForwarded), + this.complexForwarded.or(other.complexForwarded), + this.stringForwarded.or(other.stringForwarded), + this.nullForwarded.or(other.nullForwarded), + this.missingForwarded.or(other.missingForwarded), + false); + } + } + + ForwardingAnalysisResult not() { + if (this.invalid) { + return this; + } else { + return new ForwardingAnalysisResult(this.integerForwarded.not(), + this.logicalForwarded.not(), + this.doubleForwarded.not(), + this.complexForwarded.not(), + this.stringForwarded.not(), + this.nullForwarded.not(), + this.missingForwarded.not(), + false); + } + } + + ForwardingAnalysisResult forwardAll() { + if (this.invalid) { + return this; + } else { + return new ForwardingAnalysisResult(FORWARDED, + FORWARDED, + FORWARDED, + FORWARDED, + FORWARDED, + FORWARDED, + FORWARDED, + false); + } + } + + ForwardingAnalysisResult blockAll() { + if (this.invalid) { + return this; + } else { + return new ForwardingAnalysisResult(BLOCKED, + BLOCKED, + BLOCKED, + BLOCKED, + BLOCKED, + BLOCKED, + BLOCKED, + false); + } + } + + ForwardingAnalysisResult unknownAll() { + if (this.invalid) { + return this; + } else { + return new ForwardingAnalysisResult(UNKNOWN, + UNKNOWN, + UNKNOWN, + UNKNOWN, + UNKNOWN, + UNKNOWN, + UNKNOWN, + false); + } + } + + ForwardingAnalysisResult setNull(ForwardingStatus status) { + if (this.invalid) { + return this; + } else { + return new ForwardingAnalysisResult(this.integerForwarded, + this.logicalForwarded, + this.doubleForwarded, + this.complexForwarded, + this.stringForwarded, + status, + this.missingForwarded, + false); + } + } + + ForwardingAnalysisResult setMissing(ForwardingStatus status) { + if (this.invalid) { + return this; + } else { + return new ForwardingAnalysisResult(this.integerForwarded, + this.logicalForwarded, + this.doubleForwarded, + this.complexForwarded, + this.stringForwarded, + this.nullForwarded, + status, + false); + } + } +} diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/analysis/ForwardingStatus.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/analysis/ForwardingStatus.java new file mode 100644 index 0000000000..a59b7c0435 --- /dev/null +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/analysis/ForwardingStatus.java @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.builtin.casts.analysis; + +import com.oracle.truffle.r.nodes.builtin.casts.Mapper; +import com.oracle.truffle.r.runtime.RInternalError; + +public abstract class ForwardingStatus { + + public static final class Forwarded extends ForwardingStatus { + Forwarded(Mapper<?, ?> mapper) { + super((byte) 1, mapper); + } + } + + public static final ForwardingStatus BLOCKED = new ForwardingStatus((byte) 0, null) { + }; + public static final ForwardingStatus UNKNOWN = new ForwardingStatus((byte) -1, null) { + }; + public static final ForwardingStatus FORWARDED = new Forwarded(null); + + final Mapper<?, ?> mapper; + private final byte flag; + + protected ForwardingStatus(byte flag, Mapper<?, ?> mapper) { + this.flag = flag; + this.mapper = mapper; + } + + private static byte and(byte x1, byte x2) { + if (x1 < 0 && x2 < 0) { + return -1; + } else { + return (byte) (x1 * x2); + } + } + + private static byte or(byte x1, byte x2) { + if (x1 == 0 && x2 == 0) { + return 0; + } else { + return x1 + x2 >= 0 ? (byte) 1 : (byte) -1; + } + } + + private static byte not(byte x) { + if (x < 0) { + return -1; + } else { + return x == 0 ? (byte) 1 : (byte) 0; + } + } + + static ForwardingStatus fromFlag(byte flag) { + return fromFlag(flag, null); + } + + static ForwardingStatus fromFlag(byte flag, Mapper<?, ?> mapper) { + switch (flag) { + case -1: + return UNKNOWN; + case 0: + return BLOCKED; + case 1: + return mapper == null ? FORWARDED : new Forwarded(mapper); + default: + throw RInternalError.shouldNotReachHere(); + } + } + + ForwardingStatus and(ForwardingStatus other) { + if (this.mapper != null && other.mapper != null) { + // only one mapper per type is supported in this analysis + return UNKNOWN; + } + return fromFlag(and(this.flag, other.flag), this.mapper != null ? this.mapper : other.mapper); + } + + ForwardingStatus or(ForwardingStatus other) { + return fromFlag(or(this.flag, other.flag)); + } + + ForwardingStatus not() { + return fromFlag(not(this.flag)); + } + + public boolean isForwarded() { + return flag == (byte) 1 && mapper == null; + } + + public boolean isBlocked() { + return flag == (byte) 0; + } + + public boolean isUnknown() { + return flag == (byte) -1; + } +} diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/CastNodeBuilder.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/CastNodeBuilder.java index ded366fe19..aae3576a5a 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/CastNodeBuilder.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/CastNodeBuilder.java @@ -27,11 +27,11 @@ package com.oracle.truffle.r.nodes.builtin.casts.fluent; * be then used to cast anything, not only arguments. */ public final class CastNodeBuilder { - public static PreinitialPhaseBuilder<Object> newCastBuilder(String argName) { + public static PreinitialPhaseBuilder newCastBuilder(String argName) { return new PipelineBuilder(argName).fluent(); } - public static PreinitialPhaseBuilder<Object> newCastBuilder() { + public static PreinitialPhaseBuilder newCastBuilder() { return newCastBuilder(""); } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/CoercedPhaseBuilder.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/CoercedPhaseBuilder.java index 016eef5e7a..5ced98be0f 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/CoercedPhaseBuilder.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/CoercedPhaseBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2016, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -24,6 +24,7 @@ package com.oracle.truffle.r.nodes.builtin.casts.fluent; import com.oracle.truffle.r.nodes.builtin.casts.Filter; import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.nodes.RBaseNode; @@ -89,6 +90,11 @@ public final class CoercedPhaseBuilder<T extends RAbstractVector, S> extends Arg return new HeadPhaseBuilder<>(pipelineBuilder()); } + public HeadPhaseBuilder<S> findFirstOrNull() { + pipelineBuilder().appendFindFirst(RNull.instance, elementClass, null, null, null); + return new HeadPhaseBuilder<>(pipelineBuilder()); + } + public CoercedPhaseBuilder<T, S> mustBe(Filter<? super T, ? extends T> argFilter, RBaseNode callObj, RError.Message message, Object... messageArgs) { pipelineBuilder().appendMustBeStep(argFilter, callObj, message, messageArgs); return this; diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/HeadPhaseBuilder.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/HeadPhaseBuilder.java index 7248195f59..6d0d72eb97 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/HeadPhaseBuilder.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/HeadPhaseBuilder.java @@ -46,23 +46,48 @@ public final class HeadPhaseBuilder<T> extends ArgCastBuilder<T, HeadPhaseBuilde return new HeadPhaseBuilder<>(pipelineBuilder()); } + public HeadPhaseBuilder<Object> returnIf(Filter<? super T, ?> argFilter) { + pipelineBuilder().appendMapIf(argFilter, (PipelineStep<?, ?>) null, (PipelineStep<?, ?>) null, true); + return new HeadPhaseBuilder<>(pipelineBuilder()); + } + public <S extends T, R> HeadPhaseBuilder<Object> mapIf(Filter<? super T, S> argFilter, Mapper<S, R> trueBranchMapper) { - pipelineBuilder().appendMapIf(argFilter, trueBranchMapper); + pipelineBuilder().appendMapIf(argFilter, trueBranchMapper, false); + return new HeadPhaseBuilder<>(pipelineBuilder()); + } + + public <S extends T, R> HeadPhaseBuilder<Object> returnIf(Filter<? super T, S> argFilter, Mapper<S, R> trueBranchMapper) { + pipelineBuilder().appendMapIf(argFilter, trueBranchMapper, true); return new HeadPhaseBuilder<>(pipelineBuilder()); } public <S extends T, R> HeadPhaseBuilder<Object> mapIf(Filter<? super T, S> argFilter, Mapper<S, R> trueBranchMapper, Mapper<T, ?> falseBranchMapper) { - pipelineBuilder().appendMapIf(argFilter, trueBranchMapper, falseBranchMapper); + pipelineBuilder().appendMapIf(argFilter, trueBranchMapper, falseBranchMapper, false); + return new HeadPhaseBuilder<>(pipelineBuilder()); + } + + public <S extends T, R> HeadPhaseBuilder<Object> returnIf(Filter<? super T, S> argFilter, Mapper<S, R> trueBranchMapper, Mapper<T, ?> falseBranchMapper) { + pipelineBuilder().appendMapIf(argFilter, trueBranchMapper, falseBranchMapper, true); return new HeadPhaseBuilder<>(pipelineBuilder()); } public <S extends T, R> HeadPhaseBuilder<Object> mapIf(Filter<? super T, S> argFilter, PipelineStep<S, ?> trueBranch) { - pipelineBuilder().appendMapIf(argFilter, trueBranch); + pipelineBuilder().appendMapIf(argFilter, trueBranch, false); + return new HeadPhaseBuilder<>(pipelineBuilder()); + } + + public <S extends T, R> HeadPhaseBuilder<Object> returnIf(Filter<? super T, S> argFilter, PipelineStep<S, ?> trueBranch) { + pipelineBuilder().appendMapIf(argFilter, trueBranch, true); return new HeadPhaseBuilder<>(pipelineBuilder()); } public <S extends T, R> HeadPhaseBuilder<Object> mapIf(Filter<? super T, S> argFilter, PipelineStep<S, R> trueBranch, PipelineStep<T, ?> falseBranch) { - pipelineBuilder().appendMapIf(argFilter, trueBranch, falseBranch); + pipelineBuilder().appendMapIf(argFilter, trueBranch, falseBranch, false); + return new HeadPhaseBuilder<>(pipelineBuilder()); + } + + public <S extends T, R> HeadPhaseBuilder<Object> returnIf(Filter<? super T, S> argFilter, PipelineStep<S, R> trueBranch, PipelineStep<T, ?> falseBranch) { + pipelineBuilder().appendMapIf(argFilter, trueBranch, falseBranch, true); return new HeadPhaseBuilder<>(pipelineBuilder()); } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/InitialPhaseBuilder.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/InitialPhaseBuilder.java index 9b0cbf6d29..68a83118a1 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/InitialPhaseBuilder.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/InitialPhaseBuilder.java @@ -98,23 +98,48 @@ public class InitialPhaseBuilder<T> extends ArgCastBuilder<T, InitialPhaseBuilde return new InitialPhaseBuilder<>(pipelineBuilder()); } + public InitialPhaseBuilder<Object> returnIf(Filter<? super T, ?> argFilter) { + pipelineBuilder().appendMapIf(argFilter, (PipelineStep<?, ?>) null, (PipelineStep<?, ?>) null, true); + return new InitialPhaseBuilder<>(pipelineBuilder()); + } + public <S extends T, R> InitialPhaseBuilder<Object> mapIf(Filter<? super T, S> argFilter, Mapper<S, R> trueBranchMapper) { - pipelineBuilder().appendMapIf(argFilter, trueBranchMapper); + pipelineBuilder().appendMapIf(argFilter, trueBranchMapper, false); + return new InitialPhaseBuilder<>(pipelineBuilder()); + } + + public <S extends T, R> InitialPhaseBuilder<Object> returnIf(Filter<? super T, S> argFilter, Mapper<S, R> trueBranchMapper) { + pipelineBuilder().appendMapIf(argFilter, trueBranchMapper, true); return new InitialPhaseBuilder<>(pipelineBuilder()); } public <S extends T, R> InitialPhaseBuilder<Object> mapIf(Filter<? super T, S> argFilter, Mapper<S, R> trueBranchMapper, Mapper<T, ?> falseBranchMapper) { - pipelineBuilder().appendMapIf(argFilter, trueBranchMapper, falseBranchMapper); + pipelineBuilder().appendMapIf(argFilter, trueBranchMapper, falseBranchMapper, false); + return new InitialPhaseBuilder<>(pipelineBuilder()); + } + + public <S extends T, R> InitialPhaseBuilder<Object> returnIf(Filter<? super T, S> argFilter, Mapper<S, R> trueBranchMapper, Mapper<T, ?> falseBranchMapper) { + pipelineBuilder().appendMapIf(argFilter, trueBranchMapper, falseBranchMapper, true); return new InitialPhaseBuilder<>(pipelineBuilder()); } public <S extends T, R> InitialPhaseBuilder<Object> mapIf(Filter<? super T, S> argFilter, PipelineStep<?, ?> trueBranch) { - pipelineBuilder().appendMapIf(argFilter, trueBranch); + pipelineBuilder().appendMapIf(argFilter, trueBranch, false); + return new InitialPhaseBuilder<>(pipelineBuilder()); + } + + public <S extends T, R> InitialPhaseBuilder<Object> returnIf(Filter<? super T, S> argFilter, PipelineStep<?, ?> trueBranch) { + pipelineBuilder().appendMapIf(argFilter, trueBranch, true); return new InitialPhaseBuilder<>(pipelineBuilder()); } public <S extends T, R> InitialPhaseBuilder<Object> mapIf(Filter<? super T, S> argFilter, PipelineStep<?, ?> trueBranch, PipelineStep<?, ?> falseBranch) { - pipelineBuilder().appendMapIf(argFilter, trueBranch, falseBranch); + pipelineBuilder().appendMapIf(argFilter, trueBranch, falseBranch, false); + return new InitialPhaseBuilder<>(pipelineBuilder()); + } + + public <S extends T, R> InitialPhaseBuilder<Object> returnIf(Filter<? super T, S> argFilter, PipelineStep<?, ?> trueBranch, PipelineStep<?, ?> falseBranch) { + pipelineBuilder().appendMapIf(argFilter, trueBranch, falseBranch, true); return new InitialPhaseBuilder<>(pipelineBuilder()); } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/PipelineBuilder.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/PipelineBuilder.java index de6017065d..2efd6be469 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/PipelineBuilder.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/PipelineBuilder.java @@ -59,8 +59,8 @@ public final class PipelineBuilder { return PipelineToCastNode.convert(pcb.build(), getFirstStep()); } - public PreinitialPhaseBuilder<Object> fluent() { - return new PreinitialPhaseBuilder<>(this); + public PreinitialPhaseBuilder fluent() { + return new PreinitialPhaseBuilder(this); } public void appendBoxPrimitive() { @@ -88,20 +88,20 @@ public final class PipelineBuilder { append(new NotNAStep<>(naReplacement, createMessage(callObj, message, messageArgs))); } - public void appendMapIf(Filter<?, ?> argFilter, Mapper<?, ?> trueBranchMapper) { - appendMapIf(argFilter, trueBranchMapper, null); + public void appendMapIf(Filter<?, ?> argFilter, Mapper<?, ?> trueBranchMapper, boolean returns) { + appendMapIf(argFilter, trueBranchMapper, null, returns); } - public void appendMapIf(Filter<?, ?> argFilter, Mapper<?, ?> trueBranchMapper, Mapper<?, ?> falseBranchMapper) { - appendMapIf(argFilter, new MapStep<>(trueBranchMapper), falseBranchMapper == null ? null : new MapStep<>(falseBranchMapper)); + public void appendMapIf(Filter<?, ?> argFilter, Mapper<?, ?> trueBranchMapper, Mapper<?, ?> falseBranchMapper, boolean returns) { + appendMapIf(argFilter, new MapStep<>(trueBranchMapper), falseBranchMapper == null ? null : new MapStep<>(falseBranchMapper), returns); } - public void appendMapIf(Filter<?, ?> argFilter, PipelineStep<?, ?> trueBranch) { - appendMapIf(argFilter, trueBranch, null); + public void appendMapIf(Filter<?, ?> argFilter, PipelineStep<?, ?> trueBranch, boolean returns) { + appendMapIf(argFilter, trueBranch, null, returns); } - public void appendMapIf(Filter<?, ?> argFilter, PipelineStep<?, ?> trueBranch, PipelineStep<?, ?> falseBranch) { - append(new MapIfStep<>(argFilter, trueBranch, falseBranch)); + public void appendMapIf(Filter<?, ?> argFilter, PipelineStep<?, ?> trueBranch, PipelineStep<?, ?> falseBranch, boolean returns) { + append(new MapIfStep<>(argFilter, trueBranch, falseBranch, returns)); } public void appendMap(Mapper<?, ?> mapFn) { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/PipelineConfigBuilder.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/PipelineConfigBuilder.java index cb33445c4e..2492ae8922 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/PipelineConfigBuilder.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/PipelineConfigBuilder.java @@ -22,14 +22,12 @@ */ package com.oracle.truffle.r.nodes.builtin.casts.fluent; -import com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef; import com.oracle.truffle.r.nodes.builtin.casts.Mapper; import com.oracle.truffle.r.nodes.builtin.casts.MessageData; import com.oracle.truffle.r.nodes.builtin.casts.PipelineConfig; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.data.RMissing; import com.oracle.truffle.r.runtime.data.RNull; -import com.oracle.truffle.r.runtime.nodes.RBaseNode; /** * Provides fluent API for building the pipeline configuration {@link PipelineConfig}: default @@ -45,6 +43,7 @@ public final class PipelineConfigBuilder { private Mapper<? super RNull, ?> nullMapper; private MessageData missingMsg; private MessageData nullMsg; + private boolean valueForwarding = true; public PipelineConfigBuilder(String argumentName) { this.argumentName = argumentName; @@ -53,7 +52,7 @@ public final class PipelineConfigBuilder { } public PipelineConfig build() { - return new PipelineConfig(argumentName, defaultError, defaultWarning, missingMapper, nullMapper, missingMsg, nullMsg); + return new PipelineConfig(argumentName, defaultError, defaultWarning, missingMapper, nullMapper, valueForwarding, missingMsg, nullMsg); } void setDefaultError(MessageData defaultError) { @@ -64,55 +63,9 @@ public final class PipelineConfigBuilder { this.defaultWarning = defaultWarning; } - public PipelineConfigBuilder mustNotBeMissing(RBaseNode callObj, RError.Message errorMsg, Object... msgArgs) { - missingMapper = null; - missingMsg = new MessageData(callObj, errorMsg, msgArgs); + public PipelineConfigBuilder setValueForwarding(boolean flag) { + this.valueForwarding = flag; return this; } - public PipelineConfigBuilder mapMissing(Mapper<? super RMissing, ?> mapper) { - missingMapper = mapper; - missingMsg = null; - return this; - } - - public PipelineConfigBuilder mapMissing(Mapper<? super RMissing, ?> mapper, RBaseNode callObj, RError.Message warningMsg, Object... msgArgs) { - missingMapper = mapper; - missingMsg = new MessageData(callObj, warningMsg, msgArgs); - return this; - } - - public PipelineConfigBuilder allowMissing() { - return mapMissing(Predef.missingConstant()); - } - - public PipelineConfigBuilder allowMissing(RBaseNode callObj, RError.Message warningMsg, Object... msgArgs) { - return mapMissing(Predef.missingConstant(), callObj, warningMsg, msgArgs); - } - - public PipelineConfigBuilder mustNotBeNull(RBaseNode callObj, RError.Message errorMsg, Object... msgArgs) { - nullMapper = null; - nullMsg = new MessageData(callObj, errorMsg, msgArgs); - return this; - } - - public PipelineConfigBuilder mapNull(Mapper<? super RNull, ?> mapper) { - nullMapper = mapper; - nullMsg = null; - return this; - } - - public PipelineConfigBuilder mapNull(Mapper<? super RNull, ?> mapper, RBaseNode callObj, RError.Message warningMsg, Object... msgArgs) { - nullMapper = mapper; - nullMsg = new MessageData(callObj, warningMsg, msgArgs); - return this; - } - - public PipelineConfigBuilder allowNull() { - return mapNull(Predef.nullConstant()); - } - - public PipelineConfigBuilder allowNull(RBaseNode callObj, RError.Message warningMsg, Object... msgArgs) { - return mapNull(Predef.nullConstant(), callObj, warningMsg, msgArgs); - } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/PreinitialPhaseBuilder.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/PreinitialPhaseBuilder.java index 48740a3e13..8cd3486523 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/PreinitialPhaseBuilder.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/fluent/PreinitialPhaseBuilder.java @@ -22,6 +22,9 @@ */ package com.oracle.truffle.r.nodes.builtin.casts.fluent; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.missingValue; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue; + import java.util.function.Consumer; import com.oracle.truffle.r.nodes.builtin.casts.Mapper; @@ -40,76 +43,76 @@ import com.oracle.truffle.r.runtime.nodes.RBaseNode; * {@link InitialPhaseBuilder} returns that type, so once the user steps outside the configuration, * there is no way to invoke configuration related methods defined here. */ -public final class PreinitialPhaseBuilder<T> extends InitialPhaseBuilder<T> { +public final class PreinitialPhaseBuilder extends InitialPhaseBuilder<Object> { PreinitialPhaseBuilder(PipelineBuilder pipelineBuilder) { super(pipelineBuilder); } - public PreinitialPhaseBuilder<T> conf(Consumer<PipelineConfigBuilder> cfgLambda) { + public PreinitialPhaseBuilder conf(Consumer<PipelineConfigBuilder> cfgLambda) { cfgLambda.accept(pipelineBuilder().getPipelineConfig()); return this; } - public InitialPhaseBuilder<T> allowNull() { - return conf(c -> c.allowNull()); + public InitialPhaseBuilder<Object> allowNull() { + return returnIf(nullValue()); } - public InitialPhaseBuilder<T> mustNotBeNull() { - return conf(c -> c.mustNotBeNull(null, null, (Object[]) null)); + public InitialPhaseBuilder<Object> mustNotBeNull() { + return mustBe(nullValue().not()); } - public InitialPhaseBuilder<T> mustNotBeNull(RError.Message errorMsg, Object... msgArgs) { - return conf(c -> c.mustNotBeNull(null, errorMsg, msgArgs)); + public InitialPhaseBuilder<Object> mustNotBeNull(RError.Message errorMsg, Object... msgArgs) { + return mustBe(nullValue().not(), null, errorMsg, msgArgs); } - public InitialPhaseBuilder<T> mustNotBeNull(RBaseNode callObj, RError.Message errorMsg, Object... msgArgs) { - return conf(c -> c.mustNotBeNull(callObj, errorMsg, msgArgs)); + public InitialPhaseBuilder<Object> mustNotBeNull(RBaseNode callObj, RError.Message errorMsg, Object... msgArgs) { + return mustBe(nullValue().not(), callObj, errorMsg, msgArgs); } - public InitialPhaseBuilder<T> mapNull(Mapper<? super RNull, ?> mapper) { - return conf(c -> c.mapNull(mapper)); + public InitialPhaseBuilder<Object> mapNull(Mapper<RNull, ?> mapper) { + return mapIf(nullValue(), mapper); } - public InitialPhaseBuilder<T> allowMissing() { - return conf(c -> c.allowMissing()); + public InitialPhaseBuilder<Object> allowMissing() { + return returnIf(missingValue()); } - public InitialPhaseBuilder<T> mustNotBeMissing() { - return conf(c -> c.mustNotBeMissing(null, null, (Object[]) null)); + public InitialPhaseBuilder<Object> mustNotBeMissing() { + return mustBe(missingValue().not()); } - public InitialPhaseBuilder<T> mustNotBeMissing(RError.Message errorMsg, Object... msgArgs) { - return conf(c -> c.mustNotBeMissing(null, errorMsg, msgArgs)); + public InitialPhaseBuilder<Object> mustNotBeMissing(RError.Message errorMsg, Object... msgArgs) { + return mustBe(missingValue().not(), null, errorMsg, msgArgs); } - public InitialPhaseBuilder<T> mustNotBeMissing(RBaseNode callObj, RError.Message errorMsg, Object... msgArgs) { - return conf(c -> c.mustNotBeMissing(callObj, errorMsg, msgArgs)); + public InitialPhaseBuilder<Object> mustNotBeMissing(RBaseNode callObj, RError.Message errorMsg, Object... msgArgs) { + return mustBe(missingValue().not(), callObj, errorMsg, msgArgs); } - public InitialPhaseBuilder<T> mapMissing(Mapper<? super RMissing, ?> mapper) { - return conf(c -> c.mapMissing(mapper)); + public InitialPhaseBuilder<Object> mapMissing(Mapper<RMissing, ?> mapper) { + return mapIf(missingValue(), mapper); } - public InitialPhaseBuilder<T> allowNullAndMissing() { - return conf(c -> c.allowMissing().allowNull()); + public InitialPhaseBuilder<Object> allowNullAndMissing() { + return returnIf(nullValue().or(missingValue())); } @Override - public PreinitialPhaseBuilder<T> defaultError(RBaseNode callObj, RError.Message message, Object... args) { + public PreinitialPhaseBuilder defaultError(RBaseNode callObj, RError.Message message, Object... args) { pipelineBuilder().getPipelineConfig().setDefaultError(new MessageData(callObj, message, args)); pipelineBuilder().appendDefaultErrorStep(callObj, message, args); return this; } @Override - public PreinitialPhaseBuilder<T> defaultError(Message message, Object... args) { + public PreinitialPhaseBuilder defaultError(Message message, Object... args) { defaultError(null, message, args); return this; } @Override - public PreinitialPhaseBuilder<T> defaultWarning(RBaseNode callObj, Message message, Object... args) { + public PreinitialPhaseBuilder defaultWarning(RBaseNode callObj, Message message, Object... args) { pipelineBuilder().getPipelineConfig().setDefaultWarning(new MessageData(callObj, message, args)); pipelineBuilder().appendDefaultWarningStep(callObj, message, args); return this; diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/BypassNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/BypassNode.java index 2493bca82e..59ae7e9a2c 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/BypassNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/BypassNode.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2016, 2017, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -28,6 +28,7 @@ import com.oracle.truffle.r.nodes.builtin.ArgumentMapper; import com.oracle.truffle.r.nodes.builtin.casts.MessageData; import com.oracle.truffle.r.nodes.builtin.casts.PipelineConfig; import com.oracle.truffle.r.nodes.builtin.casts.PipelineToCastNode.ArgumentMapperFactory; +import com.oracle.truffle.r.nodes.unary.ConditionalMapNode.PipelineReturnException; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.data.RMissing; import com.oracle.truffle.r.runtime.data.RNull; @@ -102,7 +103,11 @@ public abstract class BypassNode extends CastNode { protected final Object executeAfterFindFirst(Object value) { if (afterFindFirst != null) { - return afterFindFirst.execute(value); + try { + return afterFindFirst.execute(value); + } catch (PipelineReturnException ret) { + return ret.getResult(); + } } else { return value; } @@ -160,7 +165,11 @@ public abstract class BypassNode extends CastNode { @Specialization(guards = "isNotHandled(x)") public Object handleOthers(Object x) { - return noHead ? x : wrappedHead.execute(x); + try { + return noHead ? x : wrappedHead.execute(x); + } catch (PipelineReturnException ret) { + return ret.getResult(); + } } protected boolean isNotHandled(Object x) { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNode.java index c6c19f8f13..081b1bc06a 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNode.java @@ -23,6 +23,7 @@ package com.oracle.truffle.r.nodes.unary; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.nodes.ControlFlowException; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.builtin.ArgumentFilter; import com.oracle.truffle.r.runtime.data.RMissing; @@ -34,23 +35,29 @@ public abstract class ConditionalMapNode extends CastNode { private final ConditionProfile conditionProfile = ConditionProfile.createBinaryProfile(); private final boolean resultForNull; private final boolean resultForMissing; + private final boolean returns; @Child private CastNode trueBranch; @Child private CastNode falseBranch; protected ConditionalMapNode(ArgumentFilter<?, ?> argFilter, CastNode trueBranch, CastNode falseBranch, boolean resultForNull, - boolean resultForMissing) { + boolean resultForMissing, boolean returns) { this.argFilter = argFilter; this.trueBranch = trueBranch; this.falseBranch = falseBranch; this.resultForNull = resultForNull; this.resultForMissing = resultForMissing; + this.returns = returns; } public static ConditionalMapNode create(ArgumentFilter<?, ?> argFilter, CastNode trueBranch, CastNode falseBranch, boolean resultForNull, - boolean resultForMissing) { - return ConditionalMapNodeGen.create(argFilter, trueBranch, falseBranch, resultForNull, resultForMissing); + boolean resultForMissing, boolean returns) { + return ConditionalMapNodeGen.create(argFilter, trueBranch, falseBranch, resultForNull, resultForMissing, returns); + } + + public boolean isReturns() { + return returns; } public ArgumentFilter<?, ?> getFilter() { @@ -68,7 +75,12 @@ public abstract class ConditionalMapNode extends CastNode { @Specialization protected Object executeNull(RNull x) { if (resultForNull) { - return trueBranch == null ? x : trueBranch.execute(x); + Object result = trueBranch == null ? x : trueBranch.execute(x); + if (returns) { + throw new PipelineReturnException(result); + } else { + return result; + } } else { return falseBranch == null ? x : falseBranch.execute(x); } @@ -77,7 +89,12 @@ public abstract class ConditionalMapNode extends CastNode { @Specialization protected Object executeMissing(RMissing x) { if (resultForMissing) { - return trueBranch == null ? x : trueBranch.execute(x); + Object result = trueBranch == null ? x : trueBranch.execute(x); + if (returns) { + throw new PipelineReturnException(result); + } else { + return result; + } } else { return falseBranch == null ? x : falseBranch.execute(x); } @@ -91,9 +108,28 @@ public abstract class ConditionalMapNode extends CastNode { @SuppressWarnings("unchecked") protected Object executeRest(Object x) { if (conditionProfile.profile(((ArgumentFilter<Object, Object>) argFilter).test(x))) { - return trueBranch == null ? x : trueBranch.execute(x); + Object result = trueBranch == null ? x : trueBranch.execute(x); + if (returns) { + throw new PipelineReturnException(result); + } else { + return result; + } } else { return falseBranch == null ? x : falseBranch.execute(x); } } + + @SuppressWarnings("serial") + public final class PipelineReturnException extends ControlFlowException { + + private final Object result; + + public PipelineReturnException(Object result) { + this.result = result; + } + + public Object getResult() { + return result; + } + } } -- GitLab