From 21625308f84b845acfc12c37341b845e73456ee6 Mon Sep 17 00:00:00 2001
From: Adam Welc <adam.welc@oracle.com>
Date: Thu, 25 Aug 2016 20:14:16 -0700
Subject: [PATCH] Rewritten parameter casts for standardGeneric builtin.

---
 .../r/nodes/builtin/base/StandardGeneric.java | 60 ++++++++++---------
 .../com/oracle/truffle/r/test/S4/TestS4.java  |  9 +++
 2 files changed, 41 insertions(+), 28 deletions(-)

diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/StandardGeneric.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/StandardGeneric.java
index 753985e79b..48572c7a8c 100644
--- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/StandardGeneric.java
+++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/StandardGeneric.java
@@ -12,10 +12,13 @@
  */
 package com.oracle.truffle.r.nodes.builtin.base;
 
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.*;
 import static com.oracle.truffle.r.runtime.RVisibility.CUSTOM;
 import static com.oracle.truffle.r.runtime.builtins.RBehavior.COMPLEX;
 import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE;
 
+import java.util.function.Function;
+
 import com.oracle.truffle.api.CompilerDirectives;
 import com.oracle.truffle.api.dsl.Specialization;
 import com.oracle.truffle.api.frame.MaterializedFrame;
@@ -28,6 +31,8 @@ import com.oracle.truffle.r.nodes.attributes.AttributeAccess;
 import com.oracle.truffle.r.nodes.attributes.AttributeAccessNodeGen;
 import com.oracle.truffle.r.nodes.builtin.CastBuilder;
 import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
+import com.oracle.truffle.r.nodes.function.ClassHierarchyScalarNode;
+import com.oracle.truffle.r.nodes.function.ClassHierarchyScalarNodeGen;
 import com.oracle.truffle.r.nodes.objects.CollectGenericArgumentsNode;
 import com.oracle.truffle.r.nodes.objects.CollectGenericArgumentsNodeGen;
 import com.oracle.truffle.r.nodes.objects.DispatchGeneric;
@@ -39,15 +44,12 @@ import com.oracle.truffle.r.runtime.RError;
 import com.oracle.truffle.r.runtime.RRuntime;
 import com.oracle.truffle.r.runtime.builtins.RBuiltin;
 import com.oracle.truffle.r.runtime.context.RContext;
-import com.oracle.truffle.r.runtime.data.RAttributable;
-import com.oracle.truffle.r.runtime.data.RAttributeProfiles;
 import com.oracle.truffle.r.runtime.data.RAttributes;
 import com.oracle.truffle.r.runtime.data.RFunction;
 import com.oracle.truffle.r.runtime.data.RList;
 import com.oracle.truffle.r.runtime.data.RMissing;
 import com.oracle.truffle.r.runtime.data.RNull;
 import com.oracle.truffle.r.runtime.data.RStringVector;
-import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
 import com.oracle.truffle.r.runtime.env.REnvironment;
 
 // transcribed from src/main/objects.c
@@ -63,6 +65,7 @@ public abstract class StandardGeneric extends RBuiltinNode {
     @Child private LocalReadVariableNode readSigARgs = LocalReadVariableNode.create(RRuntime.DOT_SIG_ARGS, true);
     @Child private CollectGenericArgumentsNode collectArgumentsNode;
     @Child private DispatchGeneric dispatchGeneric = DispatchGenericNodeGen.create();
+    @Child private ClassHierarchyScalarNode classNode;
 
     @Child private CastNode castIntScalar;
     @Child private CastNode castStringScalar;
@@ -77,10 +80,22 @@ public abstract class StandardGeneric extends RBuiltinNode {
     private final BranchProfile noGenFunFound = BranchProfile.create();
     private final ConditionProfile sameNamesProfile = ConditionProfile.createBinaryProfile();
 
-    private final RAttributeProfiles attrProfiles = RAttributeProfiles.create();
+    private String argClass(Object arg) {
+        if (classNode == null) {
+            classNode = insert(ClassHierarchyScalarNodeGen.create());
+        }
+        return classNode.executeString(arg);
+    }
+
+    @Override
+    protected void createCasts(CastBuilder casts) {
+        casts.arg("f").defaultError(RError.Message.GENERIC, "argument to 'standardGeneric' must be a non-empty character string").mustBe(
+                        stringValue()).asStringVector().findFirst().mustBe(lengthGt(0));
+        Function<Object, Object> argClass = this::argClass;
+        casts.arg("fdef").asAttributable(true, true, true).mustBe(missingValue().or(instanceOf(RFunction.class)), RError.SHOW_CALLER, RError.Message.EXPECTED_GENERIC, argClass);
+    }
 
-    private Object stdGenericInternal(VirtualFrame frame, RAbstractStringVector fVec, RFunction fdef) {
-        String fname = fVec.getDataAt(0);
+    private Object stdGenericInternal(VirtualFrame frame, String fname, RFunction fdef) {
         MaterializedFrame fnFrame = fdef.getEnclosingFrame();
         REnvironment mtable = (REnvironment) readMTableFirst.execute(frame, fnFrame);
         if (mtable == null) {
@@ -105,7 +120,7 @@ public abstract class StandardGeneric extends RBuiltinNode {
         return ret;
     }
 
-    private Object getFunction(VirtualFrame frame, RAbstractStringVector fVec, String fname, Object fnObj) {
+    private Object getFunction(VirtualFrame frame, String fname, Object fnObj) {
         if (fnObj == RNull.instance) {
             noGenFunFound.enter();
             return null;
@@ -128,7 +143,7 @@ public abstract class StandardGeneric extends RBuiltinNode {
         }
         String gen = (String) castStringScalar.execute(genObj);
         if (sameNamesProfile.profile(gen == fname)) {
-            return stdGenericInternal(frame, fVec, fn);
+            return stdGenericInternal(frame, fname, fn);
         } else {
             // in many cases == is good enough (and this will be the fastest path), but it's not
             // always sufficient
@@ -136,21 +151,20 @@ public abstract class StandardGeneric extends RBuiltinNode {
                 noGenFunFound.enter();
                 return null;
             }
-            return stdGenericInternal(frame, fVec, fn);
+            return stdGenericInternal(frame, fname, fn);
         }
     }
 
-    @Specialization(guards = "fVec.getLength() > 0")
-    protected Object stdGeneric(VirtualFrame frame, RAbstractStringVector fVec, RFunction fdef) {
-        return stdGenericInternal(frame, fVec, fdef);
+    @Specialization
+    protected Object stdGeneric(VirtualFrame frame, String fname, RFunction fdef) {
+        return stdGenericInternal(frame, fname, fdef);
     }
 
-    @Specialization(guards = "fVec.getLength() > 0")
-    protected Object stdGeneric(VirtualFrame frame, RAbstractStringVector fVec, @SuppressWarnings("unused") RMissing fdef) {
-        String fname = fVec.getDataAt(0);
+    @Specialization
+    protected Object stdGeneric(VirtualFrame frame, String fname, @SuppressWarnings("unused") RMissing fdef) {
         int n = RArguments.getDepth(frame);
         Object fnObj = RArguments.getFunction(frame);
-        fnObj = getFunction(frame, fVec, fname, fnObj);
+        fnObj = getFunction(frame, fname, fnObj);
         if (fnObj != null) {
             return fnObj;
         }
@@ -161,10 +175,10 @@ public abstract class StandardGeneric extends RBuiltinNode {
         }
         // TODO: GNU R counts to (i < n) - does their equivalent of getDepth return a different
         // value
-        // TODO; shouldn't we count from n to 0?
+        // TODO: shouldn't we count from n to 0?
         for (int i = 0; i <= n; i++) {
             fnObj = sysFunction.executeObject(frame, i);
-            fnObj = getFunction(frame, fVec, fname, fnObj);
+            fnObj = getFunction(frame, fname, fnObj);
             if (fnObj != null) {
                 return fnObj;
             }
@@ -172,14 +186,4 @@ public abstract class StandardGeneric extends RBuiltinNode {
         throw RError.error(this, RError.Message.STD_GENERIC_WRONG_CALL, fname);
     }
 
-    @Specialization
-    protected Object stdGeneric(Object fVec, RAttributable fdef) {
-        if (!(fVec instanceof String || (fVec instanceof RAbstractStringVector && ((RAbstractStringVector) fVec).getLength() > 0))) {
-            throw RError.error(this, RError.Message.GENERIC, "argument to 'standardGeneric' must be a non-empty character string");
-        } else {
-            RStringVector cl = fdef.getClassAttr(attrProfiles);
-            // not a GNU R error message
-            throw RError.error(this, RError.Message.EXPECTED_GENERIC, cl.getLength() == 0 ? RRuntime.STRING_NA : cl.getDataAt(0));
-        }
-    }
 }
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/S4/TestS4.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/S4/TestS4.java
index d903558850..41a29e7c30 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/S4/TestS4.java
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/S4/TestS4.java
@@ -112,6 +112,15 @@ public class TestS4 extends TestRBase {
 
     }
 
+    @Test
+    public void testStdGeneric() {
+        assertEval("{ standardGeneric(42) }");
+        assertEval("{ standardGeneric(character()) }");
+        assertEval("{ standardGeneric(\"\") }");
+        assertEval("{ standardGeneric(\"foo\", 42) }");
+        assertEval("{ x<-42; class(x)<-character(); standardGeneric(\"foo\", x) }");
+    }
+
     @Override
     public String getTestDir() {
         return "S4";
-- 
GitLab