From bdb0cd8cfb159939877579cc95faf342684342ca Mon Sep 17 00:00:00 2001 From: Lukas Stadler <lukas.stadler@oracle.com> Date: Wed, 10 Aug 2016 17:33:57 +0200 Subject: [PATCH] simplify some argument cast nodes, remove special RNull / null mapping --- .../r/nodes/casts/PredefMappersSamplers.java | 2 +- ...er.java => ConditionalMapNodeSampler.java} | 4 +- ...GenSampler.java => FilterNodeSampler.java} | 5 +- ...odeGenSampler.java => MapNodeSampler.java} | 4 +- .../AbstractPredicateArgumentFilter.java | 4 +- .../truffle/r/nodes/builtin/CastBuilder.java | 62 ++++++++++--------- .../r/nodes/unary/ConditionalMapNode.java | 37 +++++------ .../truffle/r/nodes/unary/FilterNode.java | 26 ++++---- .../oracle/truffle/r/nodes/unary/MapNode.java | 20 +++--- 9 files changed, 81 insertions(+), 83 deletions(-) rename com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/{ConditionalMapNodeGenSampler.java => ConditionalMapNodeSampler.java} (95%) rename com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/{FilterNodeGenSampler.java => FilterNodeSampler.java} (94%) rename com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/{MapNodeGenSampler.java => MapNodeSampler.java} (93%) diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/PredefMappersSamplers.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/PredefMappersSamplers.java index 81709b92c6..52a28c37d4 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/PredefMappersSamplers.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/PredefMappersSamplers.java @@ -97,7 +97,7 @@ public final class PredefMappersSamplers implements PredefMappers { @Override public T map(T arg) { - if (profile.profile(arg == RNull.instance || arg == null)) { + if (profile.profile(arg == RNull.instance)) { return defVal; } else { return arg; diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNodeGenSampler.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNodeSampler.java similarity index 95% rename from com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNodeGenSampler.java rename to com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNodeSampler.java index 8ccde29473..d5a6b2e65d 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNodeGenSampler.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNodeSampler.java @@ -28,13 +28,13 @@ import com.oracle.truffle.r.nodes.casts.Samples; import com.oracle.truffle.r.nodes.casts.TypeExpr; @SuppressWarnings({"rawtypes", "unchecked"}) -public class ConditionalMapNodeGenSampler extends CastNodeSampler<ConditionalMapNodeGen> { +public class ConditionalMapNodeSampler extends CastNodeSampler<ConditionalMapNode> { private final ArgumentFilterSampler argFilter; private final CastNodeSampler trueBranch; private final CastNodeSampler falseBranch; - public ConditionalMapNodeGenSampler(ConditionalMapNodeGen castNode) { + public ConditionalMapNodeSampler(ConditionalMapNode castNode) { super(castNode); argFilter = (ArgumentFilterSampler) castNode.getFilter(); trueBranch = createSampler(castNode.getTrueBranch()); diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/FilterNodeGenSampler.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/FilterNodeSampler.java similarity index 94% rename from com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/FilterNodeGenSampler.java rename to com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/FilterNodeSampler.java index bfdbdd71f7..3dcf5a9de9 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/FilterNodeGenSampler.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/FilterNodeSampler.java @@ -28,13 +28,13 @@ import com.oracle.truffle.r.nodes.casts.Samples; import com.oracle.truffle.r.nodes.casts.TypeExpr; @SuppressWarnings("rawtypes") -public class FilterNodeGenSampler extends CastNodeSampler<FilterNodeGen> { +public class FilterNodeSampler extends CastNodeSampler<FilterNode> { private final ArgumentFilterSampler filter; private final boolean isWarning; private final TypeExpr resType; - public FilterNodeGenSampler(FilterNodeGen castNode) { + public FilterNodeSampler(FilterNode castNode) { super(castNode); assert castNode.getFilter() instanceof ArgumentFilterSampler : "Check PredefFiltersSamplers is installed in Predef"; this.filter = (ArgumentFilterSampler) castNode.getFilter(); @@ -59,5 +59,4 @@ public class FilterNodeGenSampler extends CastNodeSampler<FilterNodeGen> { Samples<?> combined = samples.and(downStreamSamples); return combined; } - } diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/MapNodeGenSampler.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/MapNodeSampler.java similarity index 93% rename from com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/MapNodeGenSampler.java rename to com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/MapNodeSampler.java index 0c0844548b..1377ac3fc0 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/MapNodeGenSampler.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/unary/MapNodeSampler.java @@ -28,11 +28,11 @@ import com.oracle.truffle.r.nodes.casts.Samples; import com.oracle.truffle.r.nodes.casts.TypeExpr; @SuppressWarnings({"rawtypes", "unchecked"}) -public class MapNodeGenSampler extends CastNodeSampler<MapNodeGen> { +public class MapNodeSampler extends CastNodeSampler<MapNode> { private final ArgumentMapperSampler mapFn; - public MapNodeGenSampler(MapNodeGen castNode) { + public MapNodeSampler(MapNode castNode) { super(castNode); assert castNode.getMapper() instanceof ArgumentMapperSampler; this.mapFn = (ArgumentMapperSampler) castNode.getMapper(); diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/AbstractPredicateArgumentFilter.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/AbstractPredicateArgumentFilter.java index 43f0fe3d13..aabfb9e915 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/AbstractPredicateArgumentFilter.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/AbstractPredicateArgumentFilter.java @@ -42,10 +42,10 @@ public abstract class AbstractPredicateArgumentFilter<T, R extends T> implements @Override public boolean test(T arg) { - if (profile.profile(!isNullable && (arg == RNull.instance || arg == null))) { + if (profile.profile(!isNullable && (arg == RNull.instance))) { return false; } else { - return valuePredicate.test(arg == RNull.instance ? null : arg); + return valuePredicate.test(arg); } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/CastBuilder.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/CastBuilder.java index 564660b681..75be9cdb0f 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/CastBuilder.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/CastBuilder.java @@ -44,13 +44,13 @@ import com.oracle.truffle.r.nodes.unary.CastStringNodeGen; import com.oracle.truffle.r.nodes.unary.CastToAttributableNodeGen; import com.oracle.truffle.r.nodes.unary.CastToVectorNodeGen; import com.oracle.truffle.r.nodes.unary.ChainedCastNode; -import com.oracle.truffle.r.nodes.unary.ConditionalMapNodeGen; -import com.oracle.truffle.r.nodes.unary.FilterNodeGen; +import com.oracle.truffle.r.nodes.unary.ConditionalMapNode; +import com.oracle.truffle.r.nodes.unary.FilterNode; import com.oracle.truffle.r.nodes.unary.FindFirstNodeGen; import com.oracle.truffle.r.nodes.unary.FirstBooleanNodeGen; import com.oracle.truffle.r.nodes.unary.FirstIntNode; import com.oracle.truffle.r.nodes.unary.FirstStringNode; -import com.oracle.truffle.r.nodes.unary.MapNodeGen; +import com.oracle.truffle.r.nodes.unary.MapNode; import com.oracle.truffle.r.nodes.unary.NonNANodeGen; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RError.Message; @@ -343,7 +343,7 @@ public final class CastBuilder { @Override public <T, R extends T> TypePredicateArgumentFilter<T, R> nullValue() { - return new TypePredicateArgumentFilter<>(x -> x == RNull.instance || x == null, true); + return new TypePredicateArgumentFilter<>(x -> x == RNull.instance, true); } @Override @@ -483,6 +483,7 @@ public final class CastBuilder { @Override public <R> TypePredicateArgumentFilter<Object, R> instanceOf(Class<R> cls) { + assert cls != RNull.class : "cannot handle RNull.class with an isNullable=false filter"; return TypePredicateArgumentFilter.fromLambda(x -> cls.isInstance(x)); } @@ -608,7 +609,8 @@ public final class CastBuilder { @Override public T map(T arg) { - if (profile.profile(arg == RNull.instance || arg == null)) { + assert arg != null; + if (profile.profile(arg == RNull.instance)) { return defVal; } else { return arg; @@ -651,30 +653,30 @@ public final class CastBuilder { } public static <T> Function<ArgCastBuilder<T, ?>, CastNode> mustBe(ArgumentFilter<?, ?> argFilter, RBaseNode callObj, boolean boxPrimitives, RError.Message message, Object... messageArgs) { - return phaseBuilder -> FilterNodeGen.create(argFilter, false, callObj, message, messageArgs, boxPrimitives); + return phaseBuilder -> FilterNode.create(argFilter, false, callObj, message, messageArgs, boxPrimitives); } public static <T> Function<ArgCastBuilder<T, ?>, CastNode> mustBe(ArgumentFilter<?, ?> argFilter, boolean boxPrimitives) { - return phaseBuilder -> FilterNodeGen.create(argFilter, false, phaseBuilder.state().defaultError().callObj, phaseBuilder.state().defaultError().message, + return phaseBuilder -> FilterNode.create(argFilter, false, phaseBuilder.state().defaultError().callObj, phaseBuilder.state().defaultError().message, phaseBuilder.state().defaultError().args, boxPrimitives); } public static <T> Function<ArgCastBuilder<T, ?>, CastNode> shouldBe(ArgumentFilter<?, ?> argFilter, RBaseNode callObj, boolean boxPrimitives, RError.Message message, Object... messageArgs) { - return phaseBuilder -> FilterNodeGen.create(argFilter, true, callObj, message, messageArgs, boxPrimitives); + return phaseBuilder -> FilterNode.create(argFilter, true, callObj, message, messageArgs, boxPrimitives); } public static <T> Function<ArgCastBuilder<T, ?>, CastNode> shouldBe(ArgumentFilter<?, ?> argFilter, boolean boxPrimitives) { - return phaseBuilder -> FilterNodeGen.create(argFilter, true, phaseBuilder.state().defaultError().callObj, phaseBuilder.state().defaultError().message, + return phaseBuilder -> FilterNode.create(argFilter, true, phaseBuilder.state().defaultError().callObj, phaseBuilder.state().defaultError().message, phaseBuilder.state().defaultError().args, boxPrimitives); } public static <T> Function<ArgCastBuilder<T, ?>, CastNode> map(ArgumentMapper<?, ?> mapper) { - return phaseBuilder -> MapNodeGen.create(mapper); + return phaseBuilder -> MapNode.create(mapper); } public static <T> Function<ArgCastBuilder<T, ?>, CastNode> mapIf(ArgumentFilter<?, ?> filter, Function<ArgCastBuilder<T, ?>, CastNode> trueBranchFactory, Function<ArgCastBuilder<T, ?>, CastNode> falseBranchFactory) { - return phaseBuilder -> ConditionalMapNodeGen.create(filter, trueBranchFactory.apply(phaseBuilder), falseBranchFactory.apply(phaseBuilder)); + return phaseBuilder -> ConditionalMapNode.create(filter, trueBranchFactory.apply(phaseBuilder), falseBranchFactory.apply(phaseBuilder)); } public static <T> ChainBuilder<T> chain(CastNode firstCast) { @@ -1077,12 +1079,12 @@ public final class CastBuilder { } default THIS shouldBe(ArgumentFilter<? super T, ?> argFilter, RError.Message message, Object... messageArgs) { - state().castBuilder().insert(state().index(), FilterNodeGen.create(argFilter, true, null, message, messageArgs, state().boxPrimitives)); + state().castBuilder().insert(state().index(), FilterNode.create(argFilter, true, null, message, messageArgs, state().boxPrimitives)); return (THIS) this; } default THIS shouldBe(ArgumentFilter<? super T, ?> argFilter, RBaseNode callObj, RError.Message message, Object... messageArgs) { - state().castBuilder().insert(state().index(), FilterNodeGen.create(argFilter, true, callObj, message, messageArgs, state().boxPrimitives)); + state().castBuilder().insert(state().index(), FilterNode.create(argFilter, true, callObj, message, messageArgs, state().boxPrimitives)); return (THIS) this; } @@ -1213,7 +1215,7 @@ public final class CastBuilder { } void mustBe(ArgumentFilter<?, ?> argFilter, RBaseNode callObj, RError.Message message, Object... messageArgs) { - castBuilder().insert(index(), FilterNodeGen.create(argFilter, false, callObj, message, messageArgs, boxPrimitives)); + castBuilder().insert(index(), FilterNode.create(argFilter, false, callObj, message, messageArgs, boxPrimitives)); } void mustBe(ArgumentFilter<?, ?> argFilter) { @@ -1221,7 +1223,7 @@ public final class CastBuilder { } void shouldBe(ArgumentFilter<?, ?> argFilter, RBaseNode callObj, RError.Message message, Object... messageArgs) { - castBuilder().insert(index(), FilterNodeGen.create(argFilter, true, callObj, message, messageArgs, boxPrimitives)); + castBuilder().insert(index(), FilterNode.create(argFilter, true, callObj, message, messageArgs, boxPrimitives)); } void shouldBe(ArgumentFilter<?, ?> argFilter) { @@ -1291,26 +1293,26 @@ public final class CastBuilder { } default <S> InitialPhaseBuilder<S> map(ArgumentMapper<T, S> mapFn) { - state().castBuilder().insert(state().index(), MapNodeGen.create(mapFn)); + state().castBuilder().insert(state().index(), MapNode.create(mapFn)); return state().factory.newInitialPhaseBuilder(this); } @SuppressWarnings("overloads") default <S, R> InitialPhaseBuilder<Object> mapIf(ArgumentFilter<? super T, S> argFilter, ArgumentMapper<S, R> trueBranchMapper) { - state().castBuilder().insert(state().index(), ConditionalMapNodeGen.create(argFilter, MapNodeGen.create(trueBranchMapper), null)); + state().castBuilder().insert(state().index(), ConditionalMapNode.create(argFilter, MapNode.create(trueBranchMapper), null)); return state().factory.newInitialPhaseBuilder(this); } default <S, R> InitialPhaseBuilder<Object> mapIf(ArgumentFilter<? super T, S> argFilter, CastNode trueBranchNode) { - state().castBuilder().insert(state().index(), ConditionalMapNodeGen.create(argFilter, trueBranchNode, null)); + state().castBuilder().insert(state().index(), ConditionalMapNode.create(argFilter, trueBranchNode, null)); return state().factory.newInitialPhaseBuilder(this); } @SuppressWarnings("overloads") default <S, R> InitialPhaseBuilder<Object> mapIf(ArgumentFilter<? super T, S> argFilter, Function<ArgCastBuilder<T, ?>, CastNode> trueBranchNode) { - state().castBuilder().insert(state().index(), ConditionalMapNodeGen.create(argFilter, trueBranchNode.apply(this), null)); + state().castBuilder().insert(state().index(), ConditionalMapNode.create(argFilter, trueBranchNode.apply(this), null)); return state().factory.newInitialPhaseBuilder(this); } @@ -1319,14 +1321,14 @@ public final class CastBuilder { default <S, R> InitialPhaseBuilder<Object> mapIf(ArgumentFilter<? super T, S> argFilter, ArgumentMapper<S, R> trueBranchMapper, ArgumentMapper<T, T> falseBranchMapper) { state().castBuilder().insert( state().index(), - ConditionalMapNodeGen.create(argFilter, MapNodeGen.create(trueBranchMapper), - MapNodeGen.create(falseBranchMapper))); + ConditionalMapNode.create(argFilter, MapNode.create(trueBranchMapper), + MapNode.create(falseBranchMapper))); return state().factory.newInitialPhaseBuilder(this); } default <S, R> InitialPhaseBuilder<Object> mapIf(ArgumentFilter<? super T, S> argFilter, CastNode trueBranchNode, CastNode falseBranchNode) { - state().castBuilder().insert(state().index(), ConditionalMapNodeGen.create(argFilter, trueBranchNode, falseBranchNode)); + state().castBuilder().insert(state().index(), ConditionalMapNode.create(argFilter, trueBranchNode, falseBranchNode)); return state().factory.newInitialPhaseBuilder(this); } @@ -1334,7 +1336,7 @@ public final class CastBuilder { @SuppressWarnings("overloads") default <S, R> InitialPhaseBuilder<Object> mapIf(ArgumentFilter<? super T, S> argFilter, Function<ArgCastBuilder<T, ?>, CastNode> trueBranchNodeFactory, Function<ArgCastBuilder<T, ?>, CastNode> falseBranchNodeFactory) { - state().castBuilder().insert(state().index(), ConditionalMapNodeGen.create(argFilter, trueBranchNodeFactory.apply(this), falseBranchNodeFactory.apply(this))); + state().castBuilder().insert(state().index(), ConditionalMapNode.create(argFilter, trueBranchNodeFactory.apply(this), falseBranchNodeFactory.apply(this))); return state().factory.newInitialPhaseBuilder(this); } @@ -1494,18 +1496,18 @@ public final class CastBuilder { public interface HeadPhaseBuilder<T> extends ArgCastBuilder<T, HeadPhaseBuilder<T>> { default <S> HeadPhaseBuilder<S> map(ArgumentMapper<T, S> mapFn) { - state().castBuilder().insert(state().index(), MapNodeGen.create(mapFn)); + state().castBuilder().insert(state().index(), MapNode.create(mapFn)); return state().factory.newHeadPhaseBuilder(this); } default <S, R> HeadPhaseBuilder<Object> mapIf(ArgumentFilter<? super T, S> argFilter, ArgumentMapper<S, R> trueBranchMapper) { - state().castBuilder().insert(state().index(), ConditionalMapNodeGen.create(argFilter, MapNodeGen.create(trueBranchMapper), null)); + state().castBuilder().insert(state().index(), ConditionalMapNode.create(argFilter, MapNode.create(trueBranchMapper), null)); return state().factory.newHeadPhaseBuilder(this); } default <S, R> HeadPhaseBuilder<Object> mapIf(ArgumentFilter<? super T, S> argFilter, CastNode trueBranchNode) { - state().castBuilder().insert(state().index(), ConditionalMapNodeGen.create(argFilter, trueBranchNode, null)); + state().castBuilder().insert(state().index(), ConditionalMapNode.create(argFilter, trueBranchNode, null)); return state().factory.newHeadPhaseBuilder(this); } @@ -1514,8 +1516,8 @@ public final class CastBuilder { default <S, R> HeadPhaseBuilder<Object> mapIf(ArgumentFilter<? super T, S> argFilter, ArgumentMapper<S, R> trueBranchMapper, ArgumentMapper<T, T> falseBranchMapper) { state().castBuilder().insert( state().index(), - ConditionalMapNodeGen.create(argFilter, MapNodeGen.create(trueBranchMapper), - MapNodeGen.create(falseBranchMapper))); + ConditionalMapNode.create(argFilter, MapNode.create(trueBranchMapper), + MapNode.create(falseBranchMapper))); return state().factory.newHeadPhaseBuilder(this); } @@ -1523,13 +1525,13 @@ public final class CastBuilder { @SuppressWarnings("overloads") default <S, R> HeadPhaseBuilder<Object> mapIf(ArgumentFilter<? super T, S> argFilter, Function<ArgCastBuilder<T, ?>, CastNode> trueBranchNodeFactory, Function<ArgCastBuilder<T, ?>, CastNode> falseBranchNodeFactory) { - state().castBuilder().insert(state().index(), ConditionalMapNodeGen.create(argFilter, trueBranchNodeFactory.apply(this), falseBranchNodeFactory.apply(this))); + state().castBuilder().insert(state().index(), ConditionalMapNode.create(argFilter, trueBranchNodeFactory.apply(this), falseBranchNodeFactory.apply(this))); return state().factory.newHeadPhaseBuilder(this); } default <S, R> HeadPhaseBuilder<Object> mapIf(ArgumentFilter<? super T, S> argFilter, CastNode trueBranchNode, CastNode falseBranchNode) { - state().castBuilder().insert(state().index(), ConditionalMapNodeGen.create(argFilter, trueBranchNode, falseBranchNode)); + state().castBuilder().insert(state().index(), ConditionalMapNode.create(argFilter, trueBranchNode, falseBranchNode)); return state().factory.newHeadPhaseBuilder(this); } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNode.java index 4d4442785c..cbd81d715f 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/ConditionalMapNode.java @@ -22,23 +22,28 @@ */ package com.oracle.truffle.r.nodes.unary; -import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.builtin.ArgumentFilter; -@SuppressWarnings({"rawtypes", "unchecked"}) -public abstract class ConditionalMapNode extends CastNode { +public final class ConditionalMapNode extends CastNode { + + private final ArgumentFilter<?, ?> argFilter; + private final ConditionProfile conditionProfile = ConditionProfile.createBinaryProfile(); - private final ArgumentFilter argFilter; @Child private CastNode trueBranch; @Child private CastNode falseBranch; - protected ConditionalMapNode(ArgumentFilter<?, ?> argFilter, CastNode trueBranch, CastNode falseBranch) { + private ConditionalMapNode(ArgumentFilter<?, ?> argFilter, CastNode trueBranch, CastNode falseBranch) { this.argFilter = argFilter; this.trueBranch = trueBranch; this.falseBranch = falseBranch; } - public ArgumentFilter getFilter() { + public static ConditionalMapNode create(ArgumentFilter<?, ?> argFilter, CastNode trueBranch, CastNode falseBranch) { + return new ConditionalMapNode(argFilter, trueBranch, falseBranch); + } + + public ArgumentFilter<?, ?> getFilter() { return argFilter; } @@ -50,17 +55,13 @@ public abstract class ConditionalMapNode extends CastNode { return falseBranch; } - protected boolean doMap(Object x) { - return argFilter.test(x); - } - - @Specialization(guards = "doMap(x)") - protected Object map(Object x) { - return trueBranch == null ? x : trueBranch.execute(x); - } - - @Specialization(guards = "!doMap(x)") - protected Object noMap(Object x) { - return falseBranch == null ? x : falseBranch.execute(x); + @Override + @SuppressWarnings("unchecked") + public Object execute(Object x) { + if (conditionProfile.profile(((ArgumentFilter<Object, Object>) argFilter).test(x))) { + return trueBranch == null ? x : trueBranch.execute(x); + } else { + return falseBranch == null ? x : falseBranch.execute(x); + } } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/FilterNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/FilterNode.java index 9d3d5d5c40..fd7e9a8a57 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/FilterNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/FilterNode.java @@ -22,9 +22,8 @@ */ package com.oracle.truffle.r.nodes.unary; -import com.oracle.truffle.api.dsl.Fallback; -import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.BranchProfile; +import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.binary.BoxPrimitiveNode; import com.oracle.truffle.r.nodes.binary.BoxPrimitiveNodeGen; import com.oracle.truffle.r.nodes.builtin.ArgumentFilter; @@ -32,7 +31,7 @@ import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.nodes.RBaseNode; @SuppressWarnings({"rawtypes", "unchecked"}) -public abstract class FilterNode extends CastNode { +public final class FilterNode extends CastNode { private final ArgumentFilter filter; private final RError.Message message; @@ -42,10 +41,11 @@ public abstract class FilterNode extends CastNode { private final boolean isWarning; private final BranchProfile warningProfile = BranchProfile.create(); + private final ConditionProfile conditionProfile = ConditionProfile.createBinaryProfile(); @Child private BoxPrimitiveNode boxPrimitiveNode = BoxPrimitiveNodeGen.create(); - protected FilterNode(ArgumentFilter<?, ?> filter, boolean isWarning, RBaseNode callObj, RError.Message message, Object[] messageArgs, boolean boxPrimitives) { + private FilterNode(ArgumentFilter<?, ?> filter, boolean isWarning, RBaseNode callObj, RError.Message message, Object[] messageArgs, boolean boxPrimitives) { this.filter = filter; this.isWarning = isWarning; this.callObj = callObj == null ? this : callObj; @@ -54,6 +54,10 @@ public abstract class FilterNode extends CastNode { this.boxPrimitives = boxPrimitives; } + public static FilterNode create(ArgumentFilter<?, ?> filter, boolean isWarning, RBaseNode callObj, RError.Message message, Object[] messageArgs, boolean boxPrimitives) { + return new FilterNode(filter, isWarning, callObj, message, messageArgs, boxPrimitives); + } + public ArgumentFilter getFilter() { return filter; } @@ -73,14 +77,11 @@ public abstract class FilterNode extends CastNode { } } - @Specialization(guards = "evalCondition(x)") - protected Object onTrue(Object x) { - return x; - } - - @Fallback - protected Object onFalse(Object x) { - handleMessage(x); + @Override + public Object execute(Object x) { + if (!conditionProfile.profile(evalCondition(x))) { + handleMessage(x); + } return x; } @@ -88,5 +89,4 @@ public abstract class FilterNode extends CastNode { Object y = boxPrimitives ? boxPrimitiveNode.execute(x) : x; return filter.test(y); } - } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/MapNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/MapNode.java index 7b833c222c..e914f791c7 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/MapNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/MapNode.java @@ -22,31 +22,27 @@ */ package com.oracle.truffle.r.nodes.unary; -import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.r.nodes.builtin.ArgumentMapper; -import com.oracle.truffle.r.runtime.data.RNull; @SuppressWarnings({"rawtypes", "unchecked"}) -public abstract class MapNode extends CastNode { +public final class MapNode extends CastNode { private final ArgumentMapper mapFn; - protected MapNode(ArgumentMapper<?, ?> mapFn) { + private MapNode(ArgumentMapper<?, ?> mapFn) { this.mapFn = mapFn; } - public ArgumentMapper getMapper() { - return mapFn; + public static MapNode create(ArgumentMapper<?, ?> mapFn) { + return new MapNode(mapFn); } - @Specialization - protected Object mapNull(RNull x) { - Object res = mapFn.map(null); - return res == null ? x : res; + public ArgumentMapper getMapper() { + return mapFn; } - @Specialization - protected Object map(Object x) { + @Override + public Object execute(Object x) { return mapFn.map(x); } } -- GitLab