Skip to content
Snippets Groups Projects
Commit 21625308 authored by Adam Welc's avatar Adam Welc
Browse files

Rewritten parameter casts for standardGeneric builtin.

parent 5c43cd2e
No related branches found
No related tags found
No related merge requests found
......@@ -12,10 +12,13 @@
*/
package com.oracle.truffle.r.nodes.builtin.base;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.*;
import static com.oracle.truffle.r.runtime.RVisibility.CUSTOM;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.COMPLEX;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
import java.util.function.Function;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.MaterializedFrame;
......@@ -28,6 +31,8 @@ import com.oracle.truffle.r.nodes.attributes.AttributeAccess;
import com.oracle.truffle.r.nodes.attributes.AttributeAccessNodeGen;
import com.oracle.truffle.r.nodes.builtin.CastBuilder;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.function.ClassHierarchyScalarNode;
import com.oracle.truffle.r.nodes.function.ClassHierarchyScalarNodeGen;
import com.oracle.truffle.r.nodes.objects.CollectGenericArgumentsNode;
import com.oracle.truffle.r.nodes.objects.CollectGenericArgumentsNodeGen;
import com.oracle.truffle.r.nodes.objects.DispatchGeneric;
......@@ -39,15 +44,12 @@ import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.context.RContext;
import com.oracle.truffle.r.runtime.data.RAttributable;
import com.oracle.truffle.r.runtime.data.RAttributeProfiles;
import com.oracle.truffle.r.runtime.data.RAttributes;
import com.oracle.truffle.r.runtime.data.RFunction;
import com.oracle.truffle.r.runtime.data.RList;
import com.oracle.truffle.r.runtime.data.RMissing;
import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
import com.oracle.truffle.r.runtime.env.REnvironment;
// transcribed from src/main/objects.c
......@@ -63,6 +65,7 @@ public abstract class StandardGeneric extends RBuiltinNode {
@Child private LocalReadVariableNode readSigARgs = LocalReadVariableNode.create(RRuntime.DOT_SIG_ARGS, true);
@Child private CollectGenericArgumentsNode collectArgumentsNode;
@Child private DispatchGeneric dispatchGeneric = DispatchGenericNodeGen.create();
@Child private ClassHierarchyScalarNode classNode;
@Child private CastNode castIntScalar;
@Child private CastNode castStringScalar;
......@@ -77,10 +80,22 @@ public abstract class StandardGeneric extends RBuiltinNode {
private final BranchProfile noGenFunFound = BranchProfile.create();
private final ConditionProfile sameNamesProfile = ConditionProfile.createBinaryProfile();
private final RAttributeProfiles attrProfiles = RAttributeProfiles.create();
private String argClass(Object arg) {
if (classNode == null) {
classNode = insert(ClassHierarchyScalarNodeGen.create());
}
return classNode.executeString(arg);
}
@Override
protected void createCasts(CastBuilder casts) {
casts.arg("f").defaultError(RError.Message.GENERIC, "argument to 'standardGeneric' must be a non-empty character string").mustBe(
stringValue()).asStringVector().findFirst().mustBe(lengthGt(0));
Function<Object, Object> argClass = this::argClass;
casts.arg("fdef").asAttributable(true, true, true).mustBe(missingValue().or(instanceOf(RFunction.class)), RError.SHOW_CALLER, RError.Message.EXPECTED_GENERIC, argClass);
}
private Object stdGenericInternal(VirtualFrame frame, RAbstractStringVector fVec, RFunction fdef) {
String fname = fVec.getDataAt(0);
private Object stdGenericInternal(VirtualFrame frame, String fname, RFunction fdef) {
MaterializedFrame fnFrame = fdef.getEnclosingFrame();
REnvironment mtable = (REnvironment) readMTableFirst.execute(frame, fnFrame);
if (mtable == null) {
......@@ -105,7 +120,7 @@ public abstract class StandardGeneric extends RBuiltinNode {
return ret;
}
private Object getFunction(VirtualFrame frame, RAbstractStringVector fVec, String fname, Object fnObj) {
private Object getFunction(VirtualFrame frame, String fname, Object fnObj) {
if (fnObj == RNull.instance) {
noGenFunFound.enter();
return null;
......@@ -128,7 +143,7 @@ public abstract class StandardGeneric extends RBuiltinNode {
}
String gen = (String) castStringScalar.execute(genObj);
if (sameNamesProfile.profile(gen == fname)) {
return stdGenericInternal(frame, fVec, fn);
return stdGenericInternal(frame, fname, fn);
} else {
// in many cases == is good enough (and this will be the fastest path), but it's not
// always sufficient
......@@ -136,21 +151,20 @@ public abstract class StandardGeneric extends RBuiltinNode {
noGenFunFound.enter();
return null;
}
return stdGenericInternal(frame, fVec, fn);
return stdGenericInternal(frame, fname, fn);
}
}
@Specialization(guards = "fVec.getLength() > 0")
protected Object stdGeneric(VirtualFrame frame, RAbstractStringVector fVec, RFunction fdef) {
return stdGenericInternal(frame, fVec, fdef);
@Specialization
protected Object stdGeneric(VirtualFrame frame, String fname, RFunction fdef) {
return stdGenericInternal(frame, fname, fdef);
}
@Specialization(guards = "fVec.getLength() > 0")
protected Object stdGeneric(VirtualFrame frame, RAbstractStringVector fVec, @SuppressWarnings("unused") RMissing fdef) {
String fname = fVec.getDataAt(0);
@Specialization
protected Object stdGeneric(VirtualFrame frame, String fname, @SuppressWarnings("unused") RMissing fdef) {
int n = RArguments.getDepth(frame);
Object fnObj = RArguments.getFunction(frame);
fnObj = getFunction(frame, fVec, fname, fnObj);
fnObj = getFunction(frame, fname, fnObj);
if (fnObj != null) {
return fnObj;
}
......@@ -161,10 +175,10 @@ public abstract class StandardGeneric extends RBuiltinNode {
}
// TODO: GNU R counts to (i < n) - does their equivalent of getDepth return a different
// value
// TODO; shouldn't we count from n to 0?
// TODO: shouldn't we count from n to 0?
for (int i = 0; i <= n; i++) {
fnObj = sysFunction.executeObject(frame, i);
fnObj = getFunction(frame, fVec, fname, fnObj);
fnObj = getFunction(frame, fname, fnObj);
if (fnObj != null) {
return fnObj;
}
......@@ -172,14 +186,4 @@ public abstract class StandardGeneric extends RBuiltinNode {
throw RError.error(this, RError.Message.STD_GENERIC_WRONG_CALL, fname);
}
@Specialization
protected Object stdGeneric(Object fVec, RAttributable fdef) {
if (!(fVec instanceof String || (fVec instanceof RAbstractStringVector && ((RAbstractStringVector) fVec).getLength() > 0))) {
throw RError.error(this, RError.Message.GENERIC, "argument to 'standardGeneric' must be a non-empty character string");
} else {
RStringVector cl = fdef.getClassAttr(attrProfiles);
// not a GNU R error message
throw RError.error(this, RError.Message.EXPECTED_GENERIC, cl.getLength() == 0 ? RRuntime.STRING_NA : cl.getDataAt(0));
}
}
}
......@@ -112,6 +112,15 @@ public class TestS4 extends TestRBase {
}
@Test
public void testStdGeneric() {
assertEval("{ standardGeneric(42) }");
assertEval("{ standardGeneric(character()) }");
assertEval("{ standardGeneric(\"\") }");
assertEval("{ standardGeneric(\"foo\", 42) }");
assertEval("{ x<-42; class(x)<-character(); standardGeneric(\"foo\", x) }");
}
@Override
public String getTestDir() {
return "S4";
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment