Skip to content
Snippets Groups Projects
Commit 75791d68 authored by Lukas Stadler's avatar Lukas Stadler
Browse files

Merge pull request #532 in G/fastr from ~LUKAS.STADLER_ORACLE.COM/fastr:feature/match_arg to master

* commit '67da42bc':
  substitute match.arg
parents ac14ea4e 67da42bc
No related branches found
No related tags found
No related merge requests found
......@@ -507,6 +507,7 @@ public class BasePackage extends RBuiltinPackage {
add(MatMult.class, MatMult::create);
add(Match.class, MatchNodeGen::create);
add(MatchFun.class, MatchFunNodeGen::create);
add(MatchArg.class, MatchArgNodeGen::create);
add(Matrix.class, MatrixNodeGen::create);
add(Max.class, MaxNodeGen::create);
add(Mean.class, MeanNodeGen::create);
......@@ -742,7 +743,7 @@ public class BasePackage extends RBuiltinPackage {
addFastPath(baseFrame, "cbind", FastPathFactory.FORCED_EAGER_ARGS);
addFastPath(baseFrame, "rbind", FastPathFactory.FORCED_EAGER_ARGS);
setContainsDispatch(baseFrame, "sys.function", "match.arg", "eval", "[.data.frame", "[[.data.frame", "[<-.data.frame", "[[<-.data.frame");
setContainsDispatch(baseFrame, "sys.function", "eval", "[.data.frame", "[[.data.frame", "[<-.data.frame", "[[<-.data.frame");
}
private static void setContainsDispatch(MaterializedFrame baseFrame, String... functions) {
......
/*
* Copyright (c) 2016, 2016, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package com.oracle.truffle.r.nodes.builtin.base;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.stringValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.toBoolean;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.COMPLEX;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.SUBSTITUTE;
import com.oracle.truffle.api.CallTarget;
import com.oracle.truffle.api.CompilerAsserts;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.dsl.TypeSystemReference;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.profiles.ConditionProfile;
import com.oracle.truffle.r.nodes.RRootNode;
import com.oracle.truffle.r.nodes.builtin.CastBuilder;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.base.MatchArgNodeGen.MatchArgInternalNodeGen;
import com.oracle.truffle.r.nodes.function.FormalArguments;
import com.oracle.truffle.r.nodes.function.PromiseHelperNode;
import com.oracle.truffle.r.nodes.unary.CastNode;
import com.oracle.truffle.r.runtime.RArguments;
import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RError.Message;
import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.context.RContext;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RFunction;
import com.oracle.truffle.r.runtime.data.RIntVector;
import com.oracle.truffle.r.runtime.data.RMissing;
import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.RPromise;
import com.oracle.truffle.r.runtime.data.RTypes;
import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
import com.oracle.truffle.r.runtime.nodes.RNode;
import com.oracle.truffle.r.runtime.nodes.RSyntaxNode;
@RBuiltin(name = "match.arg", kind = SUBSTITUTE, parameterNames = {"arg", "choices", "several.ok"}, nonEvalArgs = {0}, behavior = COMPLEX)
public abstract class MatchArg extends RBuiltinNode {
@Override
public Object[] getDefaultParameterValues() {
return new Object[]{RMissing.instance, RMissing.instance, RRuntime.LOGICAL_FALSE};
}
@TypeSystemReference(RTypes.class)
protected abstract static class MatchArgInternal extends Node {
@Child private PMatch pmatch = PMatchNodeGen.create();
@Child private Identical identical = IdenticalNodeGen.create();
@Children private final CastNode[] casts;
{
CastBuilder builder = new CastBuilder();
builder.arg(0).allowNull().asStringVector();
builder.arg(1).allowMissing().mustBe(stringValue()).asStringVector();
builder.arg(2).mustBe(logicalValue()).asLogicalVector().findFirst().map(toBoolean());
this.casts = builder.getCasts();
}
public abstract Object execute(Object arg, Object choices, Object severalOK);
public final Object castAndExecute(Object arg, Object choices, Object severalOK) {
return execute(casts[0].execute(arg), casts[1].execute(choices), casts[2].execute(severalOK));
}
@Specialization
protected String matchArgNULL(@SuppressWarnings("unused") RNull arg, RAbstractStringVector choices, @SuppressWarnings("unused") boolean severalOK,
@Cached("createBinaryProfile()") ConditionProfile isEmptyProfile) {
return isEmptyProfile.profile(choices.getLength() == 0) ? RRuntime.STRING_NA : choices.getDataAt(0);
}
private void checkEmpty(RAbstractStringVector choices, int count) {
if (count == 0) {
CompilerDirectives.transferToInterpreter();
StringBuilder choicesString = new StringBuilder();
for (int i = 0; i < choices.getLength(); i++) {
choicesString.append(i == 0 ? "" : ", ").append(RRuntime.quoteString(choices.getDataAt(i), false));
}
throw RError.error(this, Message.ARG_ONE_OF, "arg", choicesString);
}
}
private static int count(RIntVector matched) {
int count = 0;
for (int i = 0; i < matched.getLength(); i++) {
if (matched.getDataAt(i) != -1) {
count++;
}
}
return count;
}
@Specialization(guards = "!severalOK")
protected String matchArg(RAbstractStringVector arg, RAbstractStringVector choices, @SuppressWarnings("unused") boolean severalOK) {
if (identical.executeByte(arg, choices, true, true, true, true, true) == RRuntime.LOGICAL_TRUE) {
return choices.getDataAt(0);
}
if (arg.getLength() != 1) {
CompilerDirectives.transferToInterpreter();
throw RError.error(this, Message.MUST_BE_SCALAR, "arg");
}
RIntVector matched = pmatch.execute(arg, choices, -1, true);
int count = count(matched);
checkEmpty(choices, count);
if (count > 1) {
CompilerDirectives.transferToInterpreter();
throw RError.error(this, Message.MORE_THAN_ONE_MATCH, "match.arg");
}
return choices.getDataAt(matched.getDataAt(0) - 1);
}
@Specialization(guards = "severalOK")
protected Object matchArgSeveral(RAbstractStringVector arg, RAbstractStringVector choices, @SuppressWarnings("unused") boolean severalOK) {
if (arg.getLength() == 0) {
CompilerDirectives.transferToInterpreter();
throw RError.error(this, Message.MUST_BE_GE_ONE, "arg");
}
RIntVector matched = pmatch.execute(arg, choices, -1, true);
int count = count(matched);
if (count == 1) {
return choices.getDataAt(matched.getDataAt(0) - 1);
}
checkEmpty(choices, count);
String[] result = new String[count];
for (int i = 0; i < matched.getLength(); i++) {
result[i] = choices.getDataAt(matched.getDataAt(i) - 1);
}
return RDataFactory.createStringVector(result, choices.isComplete());
}
}
protected static final class MatchArgChoices extends Node {
private final CallTarget target;
private final String symbol;
@Child private RNode value;
public MatchArgChoices(VirtualFrame frame, RPromise arg) {
CompilerAsserts.neverPartOfCompilation();
RFunction function = RArguments.getFunction(frame);
assert function.getRBuiltin() == null;
this.symbol = arg.getClosure().asSymbol();
if (symbol == null) {
throw RError.error(this, Message.INVALID_USE, "match.arg");
}
RRootNode def = (RRootNode) function.getRootNode();
this.target = function.getTarget();
FormalArguments arguments = def.getFormalArguments();
for (int i = 0; i < arguments.getLength(); i++) {
assert symbol == arguments.getSignature().getName(i) || !symbol.equals(arguments.getSignature().getName(i));
if (symbol == arguments.getSignature().getName(i)) {
RNode defaultArg = arguments.getDefaultArgument(i);
if (defaultArg == null) {
this.value = RContext.getASTBuilder().constant(RSyntaxNode.INTERNAL, RDataFactory.createEmptyStringVector()).asRNode();
}
this.value = RContext.getASTBuilder().process(defaultArg.asRSyntaxNode()).asRNode();
return;
}
}
throw RError.error(RError.SHOW_CALLER, Message.INVALID_USE, "match.arg");
}
public boolean isSupported(VirtualFrame frame, RPromise arg) {
return RArguments.getFunction(frame).getTarget() == target && arg.getClosure().asSymbol() == symbol;
}
public Object execute(VirtualFrame frame) {
return value.execute(frame);
}
}
protected static MatchArgInternal createInternal() {
return MatchArgInternalNodeGen.create();
}
@Specialization(limit = "3", guards = "choicesValue.isSupported(frame, arg)")
protected Object matchArg(VirtualFrame frame, RPromise arg, @SuppressWarnings("unused") RMissing choices, Object severalOK,
@Cached("new(frame, arg)") MatchArgChoices choicesValue,
@Cached("createInternal()") MatchArgInternal internal,
@Cached("new()") PromiseHelperNode promiseHelper) {
return internal.castAndExecute(promiseHelper.evaluate(frame, arg), choicesValue.execute(frame), severalOK);
}
protected static boolean isRMissing(Object value) {
return value instanceof RMissing;
}
@Specialization(guards = "!isRMissing(choices)")
protected Object matchArg(VirtualFrame frame, RPromise arg, Object choices, Object severalOK,
@Cached("createInternal()") MatchArgInternal internal,
@Cached("new()") PromiseHelperNode promiseHelper) {
return internal.castAndExecute(promiseHelper.evaluate(frame, arg), choices, severalOK);
}
@SuppressWarnings("unused")
@Fallback
protected Object matchArgFallback(Object arg, Object choices, Object severalOK) {
throw RError.error(this, Message.GENERIC, "too many different names in match.arg");
}
}
......@@ -43,6 +43,8 @@ public abstract class PMatch extends RBuiltinNode {
private final ConditionProfile nomatchNA = ConditionProfile.createBinaryProfile();
public abstract RIntVector execute(RAbstractStringVector x, RAbstractStringVector table, int nomatch, boolean duplicatesOk);
@Override
protected void createCasts(CastBuilder casts) {
casts.arg("x").asStringVector();
......
......@@ -761,7 +761,9 @@ public final class RError extends RuntimeException {
GAP_MUST_BE_NON_NEGATIVE("'gap' must be non-negative integer"),
WRONG_PCRE_INFO("'pcre_fullinfo' returned '%d' "),
BAD_FUNCTION_EXPR("badly formed function expression"),
FIRST_ELEMENT_ONLY("only first element of '%s' argument used");
FIRST_ELEMENT_ONLY("only first element of '%s' argument used"),
MUST_BE_GE_ONE("'%s' must be of length >= 1"),
MORE_THAN_ONE_MATCH("there is more than one match in '%s'");
public final String message;
final boolean hasArgs;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment