From 21625308f84b845acfc12c37341b845e73456ee6 Mon Sep 17 00:00:00 2001 From: Adam Welc <adam.welc@oracle.com> Date: Thu, 25 Aug 2016 20:14:16 -0700 Subject: [PATCH] Rewritten parameter casts for standardGeneric builtin. --- .../r/nodes/builtin/base/StandardGeneric.java | 60 ++++++++++--------- .../com/oracle/truffle/r/test/S4/TestS4.java | 9 +++ 2 files changed, 41 insertions(+), 28 deletions(-) diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/StandardGeneric.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/StandardGeneric.java index 753985e79b..48572c7a8c 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/StandardGeneric.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/StandardGeneric.java @@ -12,10 +12,13 @@ */ package com.oracle.truffle.r.nodes.builtin.base; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.*; import static com.oracle.truffle.r.runtime.RVisibility.CUSTOM; import static com.oracle.truffle.r.runtime.builtins.RBehavior.COMPLEX; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; +import java.util.function.Function; + import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.frame.MaterializedFrame; @@ -28,6 +31,8 @@ import com.oracle.truffle.r.nodes.attributes.AttributeAccess; import com.oracle.truffle.r.nodes.attributes.AttributeAccessNodeGen; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.nodes.function.ClassHierarchyScalarNode; +import com.oracle.truffle.r.nodes.function.ClassHierarchyScalarNodeGen; import com.oracle.truffle.r.nodes.objects.CollectGenericArgumentsNode; import com.oracle.truffle.r.nodes.objects.CollectGenericArgumentsNodeGen; import com.oracle.truffle.r.nodes.objects.DispatchGeneric; @@ -39,15 +44,12 @@ import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.context.RContext; -import com.oracle.truffle.r.runtime.data.RAttributable; -import com.oracle.truffle.r.runtime.data.RAttributeProfiles; import com.oracle.truffle.r.runtime.data.RAttributes; import com.oracle.truffle.r.runtime.data.RFunction; import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RMissing; import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RStringVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.env.REnvironment; // transcribed from src/main/objects.c @@ -63,6 +65,7 @@ public abstract class StandardGeneric extends RBuiltinNode { @Child private LocalReadVariableNode readSigARgs = LocalReadVariableNode.create(RRuntime.DOT_SIG_ARGS, true); @Child private CollectGenericArgumentsNode collectArgumentsNode; @Child private DispatchGeneric dispatchGeneric = DispatchGenericNodeGen.create(); + @Child private ClassHierarchyScalarNode classNode; @Child private CastNode castIntScalar; @Child private CastNode castStringScalar; @@ -77,10 +80,22 @@ public abstract class StandardGeneric extends RBuiltinNode { private final BranchProfile noGenFunFound = BranchProfile.create(); private final ConditionProfile sameNamesProfile = ConditionProfile.createBinaryProfile(); - private final RAttributeProfiles attrProfiles = RAttributeProfiles.create(); + private String argClass(Object arg) { + if (classNode == null) { + classNode = insert(ClassHierarchyScalarNodeGen.create()); + } + return classNode.executeString(arg); + } + + @Override + protected void createCasts(CastBuilder casts) { + casts.arg("f").defaultError(RError.Message.GENERIC, "argument to 'standardGeneric' must be a non-empty character string").mustBe( + stringValue()).asStringVector().findFirst().mustBe(lengthGt(0)); + Function<Object, Object> argClass = this::argClass; + casts.arg("fdef").asAttributable(true, true, true).mustBe(missingValue().or(instanceOf(RFunction.class)), RError.SHOW_CALLER, RError.Message.EXPECTED_GENERIC, argClass); + } - private Object stdGenericInternal(VirtualFrame frame, RAbstractStringVector fVec, RFunction fdef) { - String fname = fVec.getDataAt(0); + private Object stdGenericInternal(VirtualFrame frame, String fname, RFunction fdef) { MaterializedFrame fnFrame = fdef.getEnclosingFrame(); REnvironment mtable = (REnvironment) readMTableFirst.execute(frame, fnFrame); if (mtable == null) { @@ -105,7 +120,7 @@ public abstract class StandardGeneric extends RBuiltinNode { return ret; } - private Object getFunction(VirtualFrame frame, RAbstractStringVector fVec, String fname, Object fnObj) { + private Object getFunction(VirtualFrame frame, String fname, Object fnObj) { if (fnObj == RNull.instance) { noGenFunFound.enter(); return null; @@ -128,7 +143,7 @@ public abstract class StandardGeneric extends RBuiltinNode { } String gen = (String) castStringScalar.execute(genObj); if (sameNamesProfile.profile(gen == fname)) { - return stdGenericInternal(frame, fVec, fn); + return stdGenericInternal(frame, fname, fn); } else { // in many cases == is good enough (and this will be the fastest path), but it's not // always sufficient @@ -136,21 +151,20 @@ public abstract class StandardGeneric extends RBuiltinNode { noGenFunFound.enter(); return null; } - return stdGenericInternal(frame, fVec, fn); + return stdGenericInternal(frame, fname, fn); } } - @Specialization(guards = "fVec.getLength() > 0") - protected Object stdGeneric(VirtualFrame frame, RAbstractStringVector fVec, RFunction fdef) { - return stdGenericInternal(frame, fVec, fdef); + @Specialization + protected Object stdGeneric(VirtualFrame frame, String fname, RFunction fdef) { + return stdGenericInternal(frame, fname, fdef); } - @Specialization(guards = "fVec.getLength() > 0") - protected Object stdGeneric(VirtualFrame frame, RAbstractStringVector fVec, @SuppressWarnings("unused") RMissing fdef) { - String fname = fVec.getDataAt(0); + @Specialization + protected Object stdGeneric(VirtualFrame frame, String fname, @SuppressWarnings("unused") RMissing fdef) { int n = RArguments.getDepth(frame); Object fnObj = RArguments.getFunction(frame); - fnObj = getFunction(frame, fVec, fname, fnObj); + fnObj = getFunction(frame, fname, fnObj); if (fnObj != null) { return fnObj; } @@ -161,10 +175,10 @@ public abstract class StandardGeneric extends RBuiltinNode { } // TODO: GNU R counts to (i < n) - does their equivalent of getDepth return a different // value - // TODO; shouldn't we count from n to 0? + // TODO: shouldn't we count from n to 0? for (int i = 0; i <= n; i++) { fnObj = sysFunction.executeObject(frame, i); - fnObj = getFunction(frame, fVec, fname, fnObj); + fnObj = getFunction(frame, fname, fnObj); if (fnObj != null) { return fnObj; } @@ -172,14 +186,4 @@ public abstract class StandardGeneric extends RBuiltinNode { throw RError.error(this, RError.Message.STD_GENERIC_WRONG_CALL, fname); } - @Specialization - protected Object stdGeneric(Object fVec, RAttributable fdef) { - if (!(fVec instanceof String || (fVec instanceof RAbstractStringVector && ((RAbstractStringVector) fVec).getLength() > 0))) { - throw RError.error(this, RError.Message.GENERIC, "argument to 'standardGeneric' must be a non-empty character string"); - } else { - RStringVector cl = fdef.getClassAttr(attrProfiles); - // not a GNU R error message - throw RError.error(this, RError.Message.EXPECTED_GENERIC, cl.getLength() == 0 ? RRuntime.STRING_NA : cl.getDataAt(0)); - } - } } diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/S4/TestS4.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/S4/TestS4.java index d903558850..41a29e7c30 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/S4/TestS4.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/S4/TestS4.java @@ -112,6 +112,15 @@ public class TestS4 extends TestRBase { } + @Test + public void testStdGeneric() { + assertEval("{ standardGeneric(42) }"); + assertEval("{ standardGeneric(character()) }"); + assertEval("{ standardGeneric(\"\") }"); + assertEval("{ standardGeneric(\"foo\", 42) }"); + assertEval("{ x<-42; class(x)<-character(); standardGeneric(\"foo\", x) }"); + } + @Override public String getTestDir() { return "S4"; -- GitLab