From c2f6d4040792a25ad92ae03282bc48ba8693ed78 Mon Sep 17 00:00:00 2001 From: stepan <stepan.sindelar@oracle.com> Date: Mon, 5 Sep 2016 16:52:35 +0200 Subject: [PATCH] BypassNode can bypass some atomic types This means that e.g. int will go directly to the builtin and not through the cast pipeline in the following example: arg("x").mustBe(numericValue()).asIntegerVector().findFirst(). Initial structure of intermediate representation of cast pipeline and its conversion to CastNodes. --- .../truffle/r/nodes/casts/TestCasts.java | 21 ++ .../r/nodes/builtin/casts/CastStep.java | 158 ++++++++++++ .../builtin/casts/CastStepToCastNode.java | 233 ++++++++++++++++++ .../truffle/r/nodes/builtin/casts/Filter.java | 208 ++++++++++++++++ .../truffle/r/nodes/builtin/casts/Mapper.java | 91 +++++++ .../truffle/r/nodes/unary/BypassNode.java | 177 ++++++++++--- 6 files changed, 852 insertions(+), 36 deletions(-) create mode 100644 com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/CastStep.java create mode 100644 com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/CastStepToCastNode.java create mode 100644 com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/Filter.java create mode 100644 com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/Mapper.java diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/TestCasts.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/TestCasts.java index 27d7ca784a..df43bb77b4 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/TestCasts.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/TestCasts.java @@ -277,6 +277,27 @@ public class TestCasts extends TestBase { testCompilation(new Object[]{1}, new Root("MustBeWithConstant", 1)); } + @Test + public void optimizedBypass() { + class Root extends TestRootNode<CastNode> { + + private final Object constant; + + protected Root(String name, Object constant) { + super(name, new CastBuilder().arg(0).mustBe(integerValue()).asIntegerVector().findFirst().builder().getCasts()[0]); + this.constant = constant; + } + + @Override + protected Object execute(VirtualFrame frame, Object value) { + int result = (int) node.execute(constant); + CompilerAsserts.compilationConstant(result); + return null; + } + } + testCompilation(new Object[]{1}, new Root("optimizeBypass1", 1)); + } + @Test public void testConditionalMapChainWithConstant() { class Root extends TestRootNode<CastNode> { diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/CastStep.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/CastStep.java new file mode 100644 index 0000000000..12b394fbc5 --- /dev/null +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/CastStep.java @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.builtin.casts; + +import com.oracle.truffle.r.nodes.builtin.CastBuilder.PipelineConfigBuilder; +import com.oracle.truffle.r.runtime.RType; + +/** + * Represents a single step in the cast pipeline. + */ +public abstract class CastStep { + + private CastStep next; + + public final CastStep getNext() { + return next; + } + + public final void setNext(CastStep next) { + this.next = next; + } + + public abstract <T> T accept(CastStepVisitor<T> visitor); + + public interface CastStepVisitor<T> { + T visit(PipelineConfStep step); + + T visit(FindFirstStep step); + + T visit(AsVectorStep step); + + T visit(MapStep step); + + T visit(MapIfStep step); + + T visit(FilterStep step); + + T visit(NotNAStep step); + } + + public static class PipelineConfStep extends CastStep { + private final PipelineConfigBuilder pcb; + // TODO??: just remember from the builder: boolean acceptNull, boolean acceptMissing, + // defaultError?, ... + + public PipelineConfStep(PipelineConfigBuilder pcb) { + this.pcb = pcb; + } + + public PipelineConfigBuilder getConfigBuilder() { + return pcb; + } + + @Override + public <T> T accept(CastStepVisitor<T> visitor) { + return visitor.visit(this); + } + } + + public static class NotNAStep extends CastStep { + @Override + public <T> T accept(CastStepVisitor<T> visitor) { + return visitor.visit(this); + } + } + + public static class FindFirstStep extends CastStep { + private final Object defaultValue; + private final Class<?> elementClass; + + public FindFirstStep(Object defaultValue, Class<?> elementClass) { + this.defaultValue = defaultValue; + this.elementClass = elementClass; + } + + public Object getDefaultValue() { + return defaultValue; + } + + public Class<?> getElementClass() { + return elementClass; + } + + @Override + public <T> T accept(CastStepVisitor<T> visitor) { + return visitor.visit(this); + } + } + + public static class AsVectorStep extends CastStep { + private final RType type; + + public AsVectorStep(RType type) { + assert type.isVector() && type != RType.List : "AsVectorStep supports only vector types minus list."; + this.type = type; + } + + public RType getType() { + return type; + } + + @Override + public <T> T accept(CastStepVisitor<T> visitor) { + return visitor.visit(this); + } + } + + public static class MapStep extends CastStep { + @Override + public <T> T accept(CastStepVisitor<T> visitor) { + return visitor.visit(this); + } + } + + public static class MapIfStep extends CastStep { + @Override + public <T> T accept(CastStepVisitor<T> visitor) { + return visitor.visit(this); + } + } + + public static class FilterStep extends CastStep { + private final Filter filter; + + public FilterStep(Filter filter) { + this.filter = filter; + } + + public Filter getFilter() { + return filter; + } + + @Override + public <T> T accept(CastStepVisitor<T> visitor) { + return visitor.visit(this); + } + } +} diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/CastStepToCastNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/CastStepToCastNode.java new file mode 100644 index 0000000000..4077e66bef --- /dev/null +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/CastStepToCastNode.java @@ -0,0 +1,233 @@ +/* + * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.builtin.casts; + +import com.oracle.truffle.r.nodes.builtin.ArgumentFilter; +import com.oracle.truffle.r.nodes.builtin.CastBuilder.PipelineConfigBuilder; +import com.oracle.truffle.r.nodes.builtin.casts.CastStep.AsVectorStep; +import com.oracle.truffle.r.nodes.builtin.casts.CastStep.CastStepVisitor; +import com.oracle.truffle.r.nodes.builtin.casts.CastStep.FilterStep; +import com.oracle.truffle.r.nodes.builtin.casts.CastStep.FindFirstStep; +import com.oracle.truffle.r.nodes.builtin.casts.CastStep.MapIfStep; +import com.oracle.truffle.r.nodes.builtin.casts.CastStep.MapStep; +import com.oracle.truffle.r.nodes.builtin.casts.CastStep.NotNAStep; +import com.oracle.truffle.r.nodes.builtin.casts.CastStep.PipelineConfStep; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.AndFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.CompareFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.FilterVisitor; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.NotFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.NumericFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.OrFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.RTypeFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Filter.TypeFilter; +import com.oracle.truffle.r.nodes.builtin.casts.Mapper.MapByteToBoolean; +import com.oracle.truffle.r.nodes.builtin.casts.Mapper.MapDoubleToInt; +import com.oracle.truffle.r.nodes.builtin.casts.Mapper.MapToCharAt; +import com.oracle.truffle.r.nodes.builtin.casts.Mapper.MapToValue; +import com.oracle.truffle.r.nodes.builtin.casts.Mapper.MapperVisitor; +import com.oracle.truffle.r.nodes.unary.BypassNode; +import com.oracle.truffle.r.nodes.unary.CastNode; +import com.oracle.truffle.r.nodes.unary.ChainedCastNode; +import com.oracle.truffle.r.nodes.unary.FilterNode; +import com.oracle.truffle.r.nodes.unary.FindFirstNodeGen; +import com.oracle.truffle.r.nodes.unary.MapNode; +import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.RError.Message; +import com.oracle.truffle.r.runtime.RInternalError; +import com.oracle.truffle.r.runtime.RType; +import com.oracle.truffle.r.runtime.data.RDoubleVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; +import com.oracle.truffle.r.runtime.nodes.RBaseNode; + +/** + * Converts given pipeline into corresponding cast nodes chain. + */ +public final class CastStepToCastNode { + + public static CastNode convert(PipelineConfStep firstStep) { + PipelineConfigBuilder configBuilder = firstStep.getConfigBuilder(); + + CastNodeFactory nodeFactory = new CastNodeFactory(null, null, null, true); // TODO: default + // error instead + // of nulls + CastNode prevCastNode = null; + CastStep currCastStep = firstStep.getNext(); + while (currCastStep != null) { + CastNode node = nodeFactory.create(currCastStep); + if (prevCastNode == null) { + prevCastNode = node; + } else { + CastNode finalPrevCastNode = prevCastNode; + prevCastNode = new ChainedCastNode(() -> node, () -> finalPrevCastNode); + } + + currCastStep = currCastStep.getNext(); + } + return BypassNode.create(configBuilder, prevCastNode); + } + + private static final class CastNodeFactory implements CastStepVisitor<CastNode> { + private final RBaseNode defaultCallObj; + private final RError.Message defaultMessage; + private final Object[] defaultMessageArgs; + private final boolean boxPrimitives; + + public CastNodeFactory(RBaseNode defaultCallObj, Message defaultMessage, Object[] defaultMessageArgs, boolean boxPrimitives) { + this.defaultCallObj = defaultCallObj; + this.defaultMessage = defaultMessage; + this.defaultMessageArgs = defaultMessageArgs; + this.boxPrimitives = boxPrimitives; + } + + public CastNode create(CastStep step) { + return step.accept(this); + } + + @Override + public CastNode visit(PipelineConfStep step) { + throw RInternalError.shouldNotReachHere("There can be only one PipelineConfStep " + + "in pipeline as the first node and it should have been handled by the convert method."); + } + + @Override + public CastNode visit(FindFirstStep step) { + return FindFirstNodeGen.create(step.getElementClass(), step.getDefaultValue()); + } + + @Override + public CastNode visit(FilterStep step) { + ArgumentFilter<Object, Boolean> filter = ArgumentFilterFactory.create(step.getFilter()); + // TODO: check error in step and use it instead of the default one + return FilterNode.create(filter, /* TODO: isWarning?? */false, defaultCallObj, defaultMessage, defaultMessageArgs, boxPrimitives); + } + + @Override + public CastNode visit(NotNAStep step) { + return null; + } + + @Override + public CastNode visit(AsVectorStep step) { + return null; + } + + @Override + public CastNode visit(MapStep step) { + return null; + } + + @Override + public CastNode visit(MapIfStep step) { + return null; + } + } + + private static final class ArgumentFilterFactory implements FilterVisitor<ArgumentFilter<Object, Boolean>> { + + private static final ArgumentFilterFactory INSTANCE = new ArgumentFilterFactory(); + + private ArgumentFilterFactory() { + // singleton + } + + public static ArgumentFilter<Object, Boolean> create(Filter filter) { + return filter.accept(INSTANCE); + } + + @Override + public ArgumentFilter<Object, Boolean> visit(TypeFilter filter) { + return filter.getInstanceOfLambda(); + } + + @Override + public ArgumentFilter<Object, Boolean> visit(RTypeFilter filter) { + if (filter.getType() == RType.Integer) { + return x -> x instanceof Integer || x instanceof RAbstractIntVector; + } else if (filter.getType() == RType.Double) { + return x -> x instanceof Double || x instanceof RDoubleVector; + } else { + throw RInternalError.unimplemented("TODO: more types here"); + } + } + + @Override + public ArgumentFilter<Object, Boolean> visit(NumericFilter filter) { + return x -> x instanceof Integer || x instanceof RAbstractIntVector || x instanceof Double || x instanceof RAbstractDoubleVector || x instanceof Byte || + x instanceof RAbstractLogicalVector; + } + + @Override + public ArgumentFilter<Object, Boolean> visit(CompareFilter filter) { + return null; + } + + @Override + public ArgumentFilter<Object, Boolean> visit(AndFilter filter) { + ArgumentFilter<Object, Boolean> leftFilter = filter.getLeft().accept(this); + ArgumentFilter<Object, Boolean> rightFilter = filter.getRight().accept(this); + // TODO: create and filter... + return null; + } + + @Override + public ArgumentFilter<Object, Boolean> visit(OrFilter filter) { + ArgumentFilter<Object, Boolean> leftFilter = filter.getLeft().accept(this); + ArgumentFilter<Object, Boolean> rightFilter = filter.getRight().accept(this); + // TODO: create or filter... + return null; + } + + @Override + public ArgumentFilter<Object, Boolean> visit(NotFilter filter) { + ArgumentFilter<Object, Boolean> toNegate = filter.accept(this); + // TODO: create not filter + return null; + } + } + + private static final class MapperNodeFactory implements MapperVisitor<MapNode> { + + @Override + public MapNode visit(MapToValue mapper) { + final Object value = mapper.getValue(); + return MapNode.create(x -> value); + } + + @Override + public MapNode visit(MapByteToBoolean mapper) { + return null; + } + + @Override + public MapNode visit(MapDoubleToInt mapper) { + return null; + } + + @Override + public MapNode visit(MapToCharAt mapper) { + return null; + } + } +} diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/Filter.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/Filter.java new file mode 100644 index 0000000000..7927c4fc2d --- /dev/null +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/Filter.java @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.builtin.casts; + +import static com.oracle.truffle.r.nodes.builtin.casts.CastStep.FilterStep; +import static com.oracle.truffle.r.nodes.builtin.casts.CastStep.MapStep; + +import com.oracle.truffle.r.nodes.builtin.ArgumentFilter; +import com.oracle.truffle.r.runtime.RType; + +/** + * Represents filters that can be used in {@link FilterStep} and as condition in {@link MapStep}. + */ +public abstract class Filter { + + public abstract <T> T accept(FilterVisitor<T> visitor); + + public interface FilterVisitor<T> { + T visit(TypeFilter filter); + + T visit(RTypeFilter filter); + + T visit(CompareFilter filter); + + T visit(AndFilter filter); + + T visit(OrFilter filter); + + T visit(NotFilter filter); + + T visit(NumericFilter filter); + } + + /** + * Filters specific Java class. + */ + public static final class TypeFilter extends Filter { + private final Class<?> type; + private final ArgumentFilter<Object, Boolean> instanceOfLambda; + + public TypeFilter(Class<?> type, ArgumentFilter<Object, Boolean> instanceOfLambda) { + this.type = type; + this.instanceOfLambda = instanceOfLambda; + } + + public Class<?> getType() { + return type; + } + + /** + * This is lambda in form of 'x instanceof type' in order to avoid reflective + * Class.instanceOf call. + */ + public ArgumentFilter<Object, Boolean> getInstanceOfLambda() { + return instanceOfLambda; + } + + @Override + public <T> T accept(FilterVisitor<T> visitor) { + return visitor.visit(this); + } + } + + /** + * Filters specified set of type in R sense, supports only vector types minus list. + */ + public static final class RTypeFilter extends Filter { + private final RType type; + + public RTypeFilter(RType type) { + assert type.isVector() && type != RType.List : "RTypeFilter supports only vector types minus list."; + this.type = type; + } + + public RType getType() { + return type; + } + + @Override + public <T> T accept(FilterVisitor<T> visitor) { + return visitor.visit(this); + } + } + + public static final class NumericFilter extends Filter { + @Override + public <T> T accept(FilterVisitor<T> visitor) { + return visitor.visit(this); + } + } + + /** + * Compares the real value against given value using given operation. Use the constants defined + * within this class for the operation. + */ + public static final class CompareFilter extends Filter { + public static final byte EQ = 0; + public static final byte GT = 1; + public static final byte LT = 2; + public static final byte GE = 3; + public static final byte LE = 4; + + private final byte operation; + private final Object value; + + public CompareFilter(byte operation, Object value) { + assert operation <= LE : "wrong operation value"; + this.operation = operation; + this.value = value; + } + + public Object getValue() { + return value; + } + + public byte getOperation() { + return operation; + } + + @Override + public <T> T accept(FilterVisitor<T> visitor) { + return visitor.visit(this); + } + } + + public static final class AndFilter extends Filter { + private final Filter left; + private final Filter right; + + public AndFilter(Filter left, Filter right) { + this.left = left; + this.right = right; + } + + public Filter getLeft() { + return left; + } + + public Filter getRight() { + return right; + } + + @Override + public <T> T accept(FilterVisitor<T> visitor) { + return visitor.visit(this); + } + } + + public static final class OrFilter extends Filter { + private final Filter left; + private final Filter right; + + public OrFilter(Filter left, Filter right) { + this.left = left; + this.right = right; + } + + public Filter getLeft() { + return left; + } + + public Filter getRight() { + return right; + } + + @Override + public <T> T accept(FilterVisitor<T> visitor) { + return visitor.visit(this); + } + } + + public static final class NotFilter extends Filter { + private final Filter filter; + + public NotFilter(Filter filter) { + this.filter = filter; + } + + public Filter getFilter() { + return filter; + } + + @Override + public <T> T accept(FilterVisitor<T> visitor) { + return visitor.visit(this); + } + } +} diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/Mapper.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/Mapper.java new file mode 100644 index 0000000000..12c0321318 --- /dev/null +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/Mapper.java @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2013, 2016, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.builtin.casts; + +import com.oracle.truffle.r.nodes.builtin.casts.CastStep.MapStep; + +/** + * Represents mapping used in {@link MapStep}. + */ +public abstract class Mapper { + + public abstract <T> T accept(MapperVisitor<T> visitor); + + public interface MapperVisitor<T> { + T visit(MapToValue mapper); + + T visit(MapByteToBoolean mapper); + + T visit(MapDoubleToInt mapper); + + T visit(MapToCharAt mapper); + } + + public static final class MapToValue extends Mapper { + private final Object value; + + public MapToValue(Object value) { + this.value = value; + } + + public Object getValue() { + return value; + } + + @Override + public <T> T accept(MapperVisitor<T> visitor) { + return visitor.visit(this); + } + } + + public static final class MapByteToBoolean extends Mapper { + @Override + public <T> T accept(MapperVisitor<T> visitor) { + return visitor.visit(this); + } + } + + public static final class MapDoubleToInt extends Mapper { + @Override + public <T> T accept(MapperVisitor<T> visitor) { + return visitor.visit(this); + } + } + + public static final class MapToCharAt extends Mapper { + private final int index; + + public MapToCharAt(int index) { + this.index = index; + } + + public int getIndex() { + return index; + } + + @Override + public <T> T accept(MapperVisitor<T> visitor) { + return visitor.visit(this); + } + } +} diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/BypassNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/BypassNode.java index 61293070ac..774143a1bf 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/BypassNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/unary/BypassNode.java @@ -24,12 +24,31 @@ package com.oracle.truffle.r.nodes.unary; import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.r.nodes.builtin.ArgumentMapper; import com.oracle.truffle.r.nodes.builtin.CastBuilder.DefaultError; import com.oracle.truffle.r.nodes.builtin.CastBuilder.PipelineConfigBuilder; +import com.oracle.truffle.r.nodes.unary.BypassNodeGen.BypassDoubleNodeGen; +import com.oracle.truffle.r.nodes.unary.BypassNodeGen.BypassIntegerNodeGen; import com.oracle.truffle.r.runtime.data.RMissing; import com.oracle.truffle.r.runtime.data.RNull; +/** + * The node wraps cast pipeline and handles {@code RNull} and {@code RMissing} according to + * {@link PipelineConfigBuilder}. If the pipeline contains {@code findFirst} step and RNull/RMissing + * is allowed in the config, then RNull/RMissing is routed to the logic of {@code findFirst}, i.e. + * without defaultValue, it gives error, with defaultValue, returns the defaultValue. Any mappers + * after findFirst will be applied too. + * + * The factory method {@link #create(PipelineConfigBuilder, CastNode)} creates either directly + * {@link BypassNode} or one of its protected subclasses that can also bypass single atomic values + * (these are also, like RNull/RMissing, routed to findFirst and any consecutive mappers). The + * subclasses correspond to subclasses of {@link CastBaseNode}. The idea is that if the pipeline + * until 'findFirst' contains only one 'asXYVector' step and no 'map' or 'mapIf', then we can assume + * that any atomic value of type XY can be passed directly to the 'findFirst' step (although mustBe + * could disallow values of type XY, we assume that this will not happen when asXYVector is used, + * and any checks of the value will be done after findFirst). + */ @SuppressWarnings({"rawtypes", "unchecked"}) public abstract class BypassNode extends CastNode { @@ -42,11 +61,23 @@ public abstract class BypassNode extends CastNode { private final ArgumentMapper missingMapFn; private final boolean noHead; + /** + * This is the cast pipeline itself. + */ @Child private CastNode wrappedHead; - @Child private CastNode directFindFirstNode; - private final boolean useDirectFindFirstNode; - protected BypassNode(PipelineConfigBuilder pcb, CastNode wrappedHead) { + /** + * If there is a {@link FindFirstNode} in the pipeline, this will hold copy of it. + */ + @Child private FindFirstNode directFindFirstNode; + + /** + * If there are some steps after the {@link FindFirstNode} in the cast pipeline, then this will + * hold copy of its first node (which can be chained to following nodes). + */ + @Child private CastNode afterFindFirst; + + protected BypassNode(PipelineConfigBuilder pcb, CastNode wrappedHead, FindFirstNode directFindFirstNode, CastNode afterFindFirst) { this.nullMapFn = pcb.getNullMapper(); this.isRNullBypassed = this.nullMapFn != null; this.nullMsg = pcb.getNullMessage() == null ? null : pcb.getNullMessage().fixCallObj(this); @@ -61,26 +92,42 @@ public abstract class BypassNode extends CastNode { assert this.nullMsg != null || this.isRNullBypassed; assert this.missingMsg != null || this.isRMissingBypassed; - this.directFindFirstNode = !isRNullBypassed || !isRMissingBypassed ? createDirectFindFirstNode(wrappedHead) : null; - this.useDirectFindFirstNode = directFindFirstNode != null; + this.directFindFirstNode = insertIfNotNull(directFindFirstNode); + this.afterFindFirst = insertIfNotNull(afterFindFirst); } - public static CastNode create(PipelineConfigBuilder pcb, CastNode wrappedHead) { - return BypassNodeGen.create(pcb, wrappedHead); - } - - public CastNode getWrappedHead() { + public final CastNode getWrappedHead() { return wrappedHead; } - public ArgumentMapper getNullMapper() { + public final ArgumentMapper getNullMapper() { return nullMapFn; } - public ArgumentMapper getMissingMapper() { + public final ArgumentMapper getMissingMapper() { return missingMapFn; } + protected final Object executeAfterFindFirst(Object value) { + if (directFindFirstNode != null) { + return afterFindFirst.execute(value); + } else { + return value; + } + } + + private Object executeFindFirstPipeline(Object value) { + Object result = directFindFirstNode.execute(value); + if (afterFindFirst != null) { + result = afterFindFirst.execute(result); + } + return result; + } + + private <T extends Node> T insertIfNotNull(T child) { + return child != null ? insert(child) : child; + } + @Specialization public Object bypassRNull(RNull x) { if (isRNullBypassed) { @@ -88,8 +135,8 @@ public abstract class BypassNode extends CastNode { handleArgumentWarning(x, nullMsg.callObj, nullMsg.message, nullMsg.args); } return nullMapFn.map(x); - } else if (useDirectFindFirstNode) { - return directFindFirstNode.execute(x); + } else if (directFindFirstNode != null) { + return executeFindFirstPipeline(x); } else { handleArgumentError(x, nullMsg.callObj, nullMsg.message, nullMsg.args); return x; @@ -103,8 +150,8 @@ public abstract class BypassNode extends CastNode { handleArgumentWarning(x, missingMsg.callObj, missingMsg.message, missingMsg.args); } return missingMapFn.map(x); - } else if (useDirectFindFirstNode) { - return directFindFirstNode.execute(x); + } else if (directFindFirstNode != null) { + return executeFindFirstPipeline(x); } else { handleArgumentError(x, missingMsg.callObj, missingMsg.message, missingMsg.args); return x; @@ -116,32 +163,90 @@ public abstract class BypassNode extends CastNode { return noHead ? x : wrappedHead.execute(x); } - static CastNode createDirectFindFirstNode(CastNode wrappedHead) { - ChainedCastNode parentFfh = null; - ChainedCastNode ffh = null; - - if (wrappedHead != null) { - CastNode cn = wrappedHead; - while (cn instanceof ChainedCastNode) { - ChainedCastNode chcn = (ChainedCastNode) cn; - if (chcn.getSecondCast() instanceof FindFirstNode) { - FindFirstNode ffn = (FindFirstNode) chcn.getSecondCast(); - if (ffn.getDefaultValue() != null) { - ffh = chcn; - } + /** + * Factory method that inspects the given cast pipeline and returns appropriate subclass of + * {@link BypassNode} possibly optimized for a pattern found in the pipeline. See + * {@link BypassNode} doc for details. + */ + public static CastNode create(PipelineConfigBuilder pcb, CastNode wrappedHead) { + if (wrappedHead == null) { + return BypassNodeGen.create(pcb, wrappedHead, null, null); + } + + // Here we traverse the cast chain looking for FindFirstNode, if we find it, we continue + // traversing to see if there is only single asXYVector step + boolean foundFindFirst = false; + FindFirstNode directFindFirstNode = null; + CastNode afterFindFirstNode = null; + ChainedCastNode previousCurrent = null; + CastNode current = wrappedHead; + Class singleCastBaseNodeClass = null; // represents the single asXYVector step + while (current instanceof ChainedCastNode) { + ChainedCastNode currentChained = (ChainedCastNode) current; + CastNode currentSecond = currentChained.getSecondCast(); + + if (!foundFindFirst && currentSecond instanceof FindFirstNode) { + foundFindFirst = true; + if (((FindFirstNode) currentSecond).getDefaultValue() != null) { + // we are only interested in 'findFirst' with some default value in order to map + // RNull/RMissing to it. + directFindFirstNode = (FindFirstNode) currentChained.getSecondCastFact().create(); + } + if (previousCurrent != null) { + afterFindFirstNode = previousCurrent.getSecondCastFact().create(); + } + } else if (foundFindFirst && currentSecond instanceof CastBaseNode) { + if (singleCastBaseNodeClass != null) { + singleCastBaseNodeClass = null; break; } - parentFfh = chcn; - cn = chcn.getFirstCast(); + singleCastBaseNodeClass = currentSecond.getClass(); } + + previousCurrent = currentChained; + current = currentChained.getFirstCast(); } - if (ffh == null) { - return null; - } else if (parentFfh == null) { - return ffh.getSecondCastFact().create(); + if (singleCastBaseNodeClass == null || !foundFindFirst) { + return BypassNodeGen.create(pcb, wrappedHead, directFindFirstNode, afterFindFirstNode); + } + + return createBypassByClass(pcb, wrappedHead, directFindFirstNode, afterFindFirstNode, singleCastBaseNodeClass); + } + + /** + * Depending on the {@code bypassClass} parameter creates corresponding {@code BypassXYNode} + * instance. + */ + private static BypassNode createBypassByClass(PipelineConfigBuilder pcb, CastNode wrappedHead, FindFirstNode directFindFirstNode, CastNode afterFindFirstNode, Class castNodeClass) { + if (castNodeClass == CastIntegerNode.class) { + return BypassIntegerNodeGen.create(pcb, wrappedHead, directFindFirstNode, afterFindFirstNode); + } else if (castNodeClass == CastDoubleBaseNode.class) { + return BypassDoubleNodeGen.create(pcb, wrappedHead, directFindFirstNode, afterFindFirstNode); } else { - return new ChainedCastNode(ffh.getSecondCastFact(), parentFfh.getSecondCastFact()); + return BypassNodeGen.create(pcb, wrappedHead, directFindFirstNode, afterFindFirstNode); + } + } + + protected abstract static class BypassIntegerNode extends BypassNode { + protected BypassIntegerNode(PipelineConfigBuilder pcb, CastNode wrappedHead, FindFirstNode directFindFirstNode, CastNode afterFindFirst) { + super(pcb, wrappedHead, directFindFirstNode, afterFindFirst); + } + + @Specialization + protected Object bypassInteger(int x) { + return executeAfterFindFirst(x); + } + } + + protected abstract static class BypassDoubleNode extends BypassNode { + protected BypassDoubleNode(PipelineConfigBuilder pcb, CastNode wrappedHead, FindFirstNode directFindFirstNode, CastNode afterFindFirst) { + super(pcb, wrappedHead, directFindFirstNode, afterFindFirst); + } + + @Specialization + protected Object bypassDouble(double x) { + return executeAfterFindFirst(x); } } } -- GitLab