diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java index 13458cab6101bc8ffaf836f50d4fe4fb782b79e2..8b3b8edcc9c2ef5dc130f5364527aae53d600f33 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/BasePackage.java @@ -507,6 +507,7 @@ public class BasePackage extends RBuiltinPackage { add(MatMult.class, MatMult::create); add(Match.class, MatchNodeGen::create); add(MatchFun.class, MatchFunNodeGen::create); + add(MatchArg.class, MatchArgNodeGen::create); add(Matrix.class, MatrixNodeGen::create); add(Max.class, MaxNodeGen::create); add(Mean.class, MeanNodeGen::create); @@ -742,7 +743,7 @@ public class BasePackage extends RBuiltinPackage { addFastPath(baseFrame, "cbind", FastPathFactory.FORCED_EAGER_ARGS); addFastPath(baseFrame, "rbind", FastPathFactory.FORCED_EAGER_ARGS); - setContainsDispatch(baseFrame, "sys.function", "match.arg", "eval", "[.data.frame", "[[.data.frame", "[<-.data.frame", "[[<-.data.frame"); + setContainsDispatch(baseFrame, "sys.function", "eval", "[.data.frame", "[[.data.frame", "[<-.data.frame", "[[<-.data.frame"); } private static void setContainsDispatch(MaterializedFrame baseFrame, String... functions) { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/MatchArg.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/MatchArg.java new file mode 100644 index 0000000000000000000000000000000000000000..bc6405b87dad17cc3a56e2b301868c6686a84bbc --- /dev/null +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/MatchArg.java @@ -0,0 +1,234 @@ +/* + * Copyright (c) 2016, 2016, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.r.nodes.builtin.base; + +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.stringValue; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean; +import static com.oracle.truffle.r.runtime.builtins.RBehavior.COMPLEX; +import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.SUBSTITUTE; + +import com.oracle.truffle.api.CallTarget; +import com.oracle.truffle.api.CompilerAsserts; +import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.dsl.TypeSystemReference; +import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.nodes.Node; +import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.r.nodes.RRootNode; +import com.oracle.truffle.r.nodes.builtin.CastBuilder; +import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.nodes.builtin.base.MatchArgNodeGen.MatchArgInternalNodeGen; +import com.oracle.truffle.r.nodes.function.FormalArguments; +import com.oracle.truffle.r.nodes.function.PromiseHelperNode; +import com.oracle.truffle.r.nodes.unary.CastNode; +import com.oracle.truffle.r.runtime.RArguments; +import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.RError.Message; +import com.oracle.truffle.r.runtime.RRuntime; +import com.oracle.truffle.r.runtime.builtins.RBuiltin; +import com.oracle.truffle.r.runtime.context.RContext; +import com.oracle.truffle.r.runtime.data.RDataFactory; +import com.oracle.truffle.r.runtime.data.RFunction; +import com.oracle.truffle.r.runtime.data.RIntVector; +import com.oracle.truffle.r.runtime.data.RMissing; +import com.oracle.truffle.r.runtime.data.RNull; +import com.oracle.truffle.r.runtime.data.RPromise; +import com.oracle.truffle.r.runtime.data.RTypes; +import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; +import com.oracle.truffle.r.runtime.nodes.RNode; +import com.oracle.truffle.r.runtime.nodes.RSyntaxNode; + +@RBuiltin(name = "match.arg", kind = SUBSTITUTE, parameterNames = {"arg", "choices", "several.ok"}, nonEvalArgs = {0}, behavior = COMPLEX) +public abstract class MatchArg extends RBuiltinNode { + + @Override + public Object[] getDefaultParameterValues() { + return new Object[]{RMissing.instance, RMissing.instance, RRuntime.LOGICAL_FALSE}; + } + + @TypeSystemReference(RTypes.class) + protected abstract static class MatchArgInternal extends Node { + + @Child private PMatch pmatch = PMatchNodeGen.create(); + @Child private Identical identical = IdenticalNodeGen.create(); + + @Children private final CastNode[] casts; + + { + CastBuilder builder = new CastBuilder(); + builder.arg(0).allowNull().asStringVector(); + builder.arg(1).allowMissing().mustBe(stringValue()).asStringVector(); + builder.arg(2).mustBe(logicalValue()).asLogicalVector().findFirst().map(toBoolean()); + this.casts = builder.getCasts(); + } + + public abstract Object execute(Object arg, Object choices, Object severalOK); + + public final Object castAndExecute(Object arg, Object choices, Object severalOK) { + return execute(casts[0].execute(arg), casts[1].execute(choices), casts[2].execute(severalOK)); + } + + @Specialization + protected String matchArgNULL(@SuppressWarnings("unused") RNull arg, RAbstractStringVector choices, @SuppressWarnings("unused") boolean severalOK, + @Cached("createBinaryProfile()") ConditionProfile isEmptyProfile) { + return isEmptyProfile.profile(choices.getLength() == 0) ? RRuntime.STRING_NA : choices.getDataAt(0); + } + + private void checkEmpty(RAbstractStringVector choices, int count) { + if (count == 0) { + CompilerDirectives.transferToInterpreter(); + StringBuilder choicesString = new StringBuilder(); + for (int i = 0; i < choices.getLength(); i++) { + choicesString.append(i == 0 ? "" : ", ").append(RRuntime.quoteString(choices.getDataAt(i), false)); + } + throw RError.error(this, Message.ARG_ONE_OF, "arg", choicesString); + } + } + + private static int count(RIntVector matched) { + int count = 0; + for (int i = 0; i < matched.getLength(); i++) { + if (matched.getDataAt(i) != -1) { + count++; + } + } + return count; + } + + @Specialization(guards = "!severalOK") + protected String matchArg(RAbstractStringVector arg, RAbstractStringVector choices, @SuppressWarnings("unused") boolean severalOK) { + if (identical.executeByte(arg, choices, true, true, true, true, true) == RRuntime.LOGICAL_TRUE) { + return choices.getDataAt(0); + } + if (arg.getLength() != 1) { + CompilerDirectives.transferToInterpreter(); + throw RError.error(this, Message.MUST_BE_SCALAR, "arg"); + } + RIntVector matched = pmatch.execute(arg, choices, -1, true); + int count = count(matched); + checkEmpty(choices, count); + if (count > 1) { + CompilerDirectives.transferToInterpreter(); + throw RError.error(this, Message.MORE_THAN_ONE_MATCH, "match.arg"); + } + return choices.getDataAt(matched.getDataAt(0) - 1); + } + + @Specialization(guards = "severalOK") + protected Object matchArgSeveral(RAbstractStringVector arg, RAbstractStringVector choices, @SuppressWarnings("unused") boolean severalOK) { + if (arg.getLength() == 0) { + CompilerDirectives.transferToInterpreter(); + throw RError.error(this, Message.MUST_BE_GE_ONE, "arg"); + } + RIntVector matched = pmatch.execute(arg, choices, -1, true); + int count = count(matched); + if (count == 1) { + return choices.getDataAt(matched.getDataAt(0) - 1); + } + checkEmpty(choices, count); + String[] result = new String[count]; + for (int i = 0; i < matched.getLength(); i++) { + result[i] = choices.getDataAt(matched.getDataAt(i) - 1); + } + return RDataFactory.createStringVector(result, choices.isComplete()); + } + } + + protected static final class MatchArgChoices extends Node { + + private final CallTarget target; + private final String symbol; + + @Child private RNode value; + + public MatchArgChoices(VirtualFrame frame, RPromise arg) { + CompilerAsserts.neverPartOfCompilation(); + + RFunction function = RArguments.getFunction(frame); + assert function.getRBuiltin() == null; + + this.symbol = arg.getClosure().asSymbol(); + if (symbol == null) { + throw RError.error(this, Message.INVALID_USE, "match.arg"); + } + + RRootNode def = (RRootNode) function.getRootNode(); + this.target = function.getTarget(); + FormalArguments arguments = def.getFormalArguments(); + + for (int i = 0; i < arguments.getLength(); i++) { + assert symbol == arguments.getSignature().getName(i) || !symbol.equals(arguments.getSignature().getName(i)); + if (symbol == arguments.getSignature().getName(i)) { + RNode defaultArg = arguments.getDefaultArgument(i); + if (defaultArg == null) { + this.value = RContext.getASTBuilder().constant(RSyntaxNode.INTERNAL, RDataFactory.createEmptyStringVector()).asRNode(); + } + this.value = RContext.getASTBuilder().process(defaultArg.asRSyntaxNode()).asRNode(); + return; + } + } + throw RError.error(RError.SHOW_CALLER, Message.INVALID_USE, "match.arg"); + } + + public boolean isSupported(VirtualFrame frame, RPromise arg) { + return RArguments.getFunction(frame).getTarget() == target && arg.getClosure().asSymbol() == symbol; + } + + public Object execute(VirtualFrame frame) { + return value.execute(frame); + } + } + + protected static MatchArgInternal createInternal() { + return MatchArgInternalNodeGen.create(); + } + + @Specialization(limit = "3", guards = "choicesValue.isSupported(frame, arg)") + protected Object matchArg(VirtualFrame frame, RPromise arg, @SuppressWarnings("unused") RMissing choices, Object severalOK, + @Cached("new(frame, arg)") MatchArgChoices choicesValue, + @Cached("createInternal()") MatchArgInternal internal, + @Cached("new()") PromiseHelperNode promiseHelper) { + return internal.castAndExecute(promiseHelper.evaluate(frame, arg), choicesValue.execute(frame), severalOK); + } + + protected static boolean isRMissing(Object value) { + return value instanceof RMissing; + } + + @Specialization(guards = "!isRMissing(choices)") + protected Object matchArg(VirtualFrame frame, RPromise arg, Object choices, Object severalOK, + @Cached("createInternal()") MatchArgInternal internal, + @Cached("new()") PromiseHelperNode promiseHelper) { + return internal.castAndExecute(promiseHelper.evaluate(frame, arg), choices, severalOK); + } + + @SuppressWarnings("unused") + @Fallback + protected Object matchArgFallback(Object arg, Object choices, Object severalOK) { + throw RError.error(this, Message.GENERIC, "too many different names in match.arg"); + } +} diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/PMatch.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/PMatch.java index 4f841fb44d50e02358d9a6f652bc4736735e83ee..ad4cbd55f2405b4a34761c15835f1a5b45127e3f 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/PMatch.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/PMatch.java @@ -43,6 +43,8 @@ public abstract class PMatch extends RBuiltinNode { private final ConditionProfile nomatchNA = ConditionProfile.createBinaryProfile(); + public abstract RIntVector execute(RAbstractStringVector x, RAbstractStringVector table, int nomatch, boolean duplicatesOk); + @Override protected void createCasts(CastBuilder casts) { casts.arg("x").asStringVector(); diff --git a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java index 4fcaa6b519709b6153d65cf5d29636089e3e9828..9337e810d8a898aa8f6c1257525fdbbbdce62cdb 100644 --- a/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java +++ b/com.oracle.truffle.r.runtime/src/com/oracle/truffle/r/runtime/RError.java @@ -761,7 +761,9 @@ public final class RError extends RuntimeException { GAP_MUST_BE_NON_NEGATIVE("'gap' must be non-negative integer"), WRONG_PCRE_INFO("'pcre_fullinfo' returned '%d' "), BAD_FUNCTION_EXPR("badly formed function expression"), - FIRST_ELEMENT_ONLY("only first element of '%s' argument used"); + FIRST_ELEMENT_ONLY("only first element of '%s' argument used"), + MUST_BE_GE_ONE("'%s' must be of length >= 1"), + MORE_THAN_ONE_MATCH("there is more than one match in '%s'"); public final String message; final boolean hasArgs;