From 098c1137b06ebe26b37f69888b0583e4224e121d Mon Sep 17 00:00:00 2001
From: Zbynek Slajchrt <zbynek.slajchrt@oracle.com>
Date: Fri, 9 Sep 2016 11:50:25 +0200
Subject: [PATCH] rbdiag supports external builtins (no chimney-sweeping atm)

---
 .../truffle/r/nodes/test/ChimneySweeping.java |  20 +-
 .../r/nodes/test/RBuiltinDiagnostics.java     | 171 +++++++++++++++---
 mx.fastr/mx_fastr.py                          |   2 +
 3 files changed, 160 insertions(+), 33 deletions(-)

diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/ChimneySweeping.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/ChimneySweeping.java
index 3c1dd496b8..57c5f03a4f 100644
--- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/ChimneySweeping.java
+++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/ChimneySweeping.java
@@ -44,9 +44,12 @@ import com.oracle.truffle.api.dsl.UnsupportedSpecializationException;
 import com.oracle.truffle.api.vm.PolyglotEngine;
 import com.oracle.truffle.api.vm.PolyglotEngine.Value;
 import com.oracle.truffle.r.nodes.builtin.RBuiltinFactory;
+import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
 import com.oracle.truffle.r.nodes.casts.CastNodeSampler;
 import com.oracle.truffle.r.nodes.casts.Samples;
 import com.oracle.truffle.r.nodes.test.RBuiltinDiagnostics.DiagConfig;
+import com.oracle.truffle.r.nodes.test.RBuiltinDiagnostics.RBuiltinDiagFactory;
+import com.oracle.truffle.r.nodes.test.RBuiltinDiagnostics.RIntBuiltinDiagFactory;
 import com.oracle.truffle.r.nodes.test.RBuiltinDiagnostics.SingleBuiltinDiagnostics;
 import com.oracle.truffle.r.nodes.test.TestUtilities.NodeHandle;
 import com.oracle.truffle.r.nodes.unary.CastNode;
@@ -54,6 +57,7 @@ import com.oracle.truffle.r.runtime.RDeparse;
 import com.oracle.truffle.r.runtime.RError;
 import com.oracle.truffle.r.runtime.RSource;
 import com.oracle.truffle.r.runtime.ResourceHandlerFactory;
+import com.oracle.truffle.r.runtime.builtins.RBuiltinKind;
 import com.oracle.truffle.r.runtime.data.RDataFactory;
 import com.oracle.truffle.r.runtime.data.RList;
 import com.oracle.truffle.r.runtime.data.RNull;
@@ -129,8 +133,12 @@ class ChimneySweeping extends SingleBuiltinDiagnostics {
         }
 
         @Override
-        public SingleBuiltinDiagnostics createBuiltinDiagnostics(RBuiltinFactory bf) {
-            return new ChimneySweeping(this, bf);
+        public SingleBuiltinDiagnostics createBuiltinDiagnostics(RBuiltinDiagFactory bf) {
+            if (bf instanceof RIntBuiltinDiagFactory) {
+                return new ChimneySweeping(this, (RIntBuiltinDiagFactory) bf);
+            } else {
+                throw new UnsupportedOperationException("Only non-external builtins supported for chimney-sweeping atm");
+            }
         }
 
         private static TestOutputManager loadTestOutputManager() throws IOException {
@@ -151,16 +159,18 @@ class ChimneySweeping extends SingleBuiltinDiagnostics {
     private final List<Samples<?>> argSamples;
     private final ChimneySweepingSuite diagSuite;
     private final Set<RList> validArgsList;
+    private final RBuiltinKind kind;
 
     private final Set<List<String>> printedOutputPairs = new HashSet<>();
     private final Set<String> printedErrors = new HashSet<>();
     private int sweepCounter = 0;
 
-    ChimneySweeping(ChimneySweepingSuite diagSuite, RBuiltinFactory builtinFactory) {
+    ChimneySweeping(ChimneySweepingSuite diagSuite, RIntBuiltinDiagFactory builtinFactory) {
         super(diagSuite, builtinFactory);
         this.diagSuite = diagSuite;
         this.validArgsList = extractValidArgsForBuiltin();
         this.argSamples = createSamples();
+        this.kind = builtinFactory.getBuiltinKind();
     }
 
     @Override
@@ -274,7 +284,7 @@ class ChimneySweeping extends SingleBuiltinDiagnostics {
 
         try {
             String snippetAnchor;
-            switch (annotation.kind()) {
+            switch (kind) {
                 case INTERNAL:
                     snippetAnchor = ".Internal(" + builtinName + "(";
                     break;
@@ -385,7 +395,7 @@ class ChimneySweeping extends SingleBuiltinDiagnostics {
             }
 
             String call;
-            switch (annotation.kind()) {
+            switch (kind) {
                 case INTERNAL:
                     call = ".Internal(" + builtinName + "(" + sb + "))";
                     break;
diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/RBuiltinDiagnostics.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/RBuiltinDiagnostics.java
index 1e6844fbbc..e9079da634 100644
--- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/RBuiltinDiagnostics.java
+++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/RBuiltinDiagnostics.java
@@ -23,6 +23,7 @@
 package com.oracle.truffle.r.nodes.test;
 
 import java.lang.reflect.Method;
+import java.lang.reflect.Modifier;
 import java.lang.reflect.Type;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -30,6 +31,7 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -38,7 +40,7 @@ import com.oracle.truffle.api.frame.Frame;
 import com.oracle.truffle.r.nodes.access.variables.ReadVariableNode;
 import com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef;
 import com.oracle.truffle.r.nodes.builtin.RBuiltinFactory;
-import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
+import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode;
 import com.oracle.truffle.r.nodes.builtin.base.BasePackage;
 import com.oracle.truffle.r.nodes.casts.CastNodeSampler;
 import com.oracle.truffle.r.nodes.casts.CastUtils;
@@ -52,6 +54,7 @@ import com.oracle.truffle.r.nodes.test.ChimneySweeping.ChimneySweepingSuite;
 import com.oracle.truffle.r.nodes.unary.CastNode;
 import com.oracle.truffle.r.runtime.ArgumentsSignature;
 import com.oracle.truffle.r.runtime.builtins.RBuiltin;
+import com.oracle.truffle.r.runtime.builtins.RBuiltinKind;
 import com.oracle.truffle.r.runtime.data.RMissing;
 import com.oracle.truffle.r.runtime.data.RNull;
 import com.oracle.truffle.r.runtime.nodes.RNode;
@@ -98,19 +101,26 @@ public class RBuiltinDiagnostics {
         }
     }
 
-    public SingleBuiltinDiagnostics createBuiltinDiagnostics(RBuiltinFactory bf) {
+    public SingleBuiltinDiagnostics createBuiltinDiagnostics(RBuiltinDiagFactory bf) {
         return new SingleBuiltinDiagnostics(this, bf);
     }
 
     public void diagnoseSingleBuiltin(String builtinName) throws Exception {
         BasePackage bp = new BasePackage();
         RBuiltinFactory bf = bp.lookupByName(builtinName);
+        RBuiltinDiagFactory bdf;
         if (bf == null) {
-            System.out.println("No builtin '" + builtinName + "' found");
-            return;
+            try {
+                bdf = RExtBuiltinDiagFactory.create(builtinName);
+            } catch (Exception e) {
+                System.out.println("No builtin '" + builtinName + "' found");
+                return;
+            }
+        } else {
+            bdf = new RIntBuiltinDiagFactory(bf);
         }
 
-        createBuiltinDiagnostics(bf).diagnoseBuiltin();
+        createBuiltinDiagnostics(bdf).diagnoseBuiltin();
 
         System.out.println("Finished");
         System.out.println("--------");
@@ -122,7 +132,7 @@ public class RBuiltinDiagnostics {
         BasePackage bp = new BasePackage();
         for (RBuiltinFactory bf : bp.getBuiltins().values()) {
             try {
-                createBuiltinDiagnostics(bf).diagnoseBuiltin();
+                createBuiltinDiagnostics(new RIntBuiltinDiagFactory((bf))).diagnoseBuiltin();
             } catch (Exception e) {
                 System.out.println(bf.getName() + " failed: " + e.getMessage());
             }
@@ -136,27 +146,23 @@ public class RBuiltinDiagnostics {
 
     static class SingleBuiltinDiagnostics {
         private final RBuiltinDiagnostics diagSuite;
-        final RBuiltinFactory builtinFactory;
+        final RBuiltinDiagFactory builtinFactory;
         final String builtinName;
         final int argLength;
         final String[] parameterNames;
         final CastNode[] castNodes;
-        final Class<?> builtinClass;
-        final RBuiltin annotation;
         final List<Method> specMethods;
         final List<TypeExpr> argResultSets;
         final HashMap<Method, List<Set<Cast>>> convResultTypePerSpec;
         final Set<List<Type>> nonCoveredArgsSet;
 
-        SingleBuiltinDiagnostics(RBuiltinDiagnostics diagSuite, RBuiltinFactory builtinFactory) {
+        SingleBuiltinDiagnostics(RBuiltinDiagnostics diagSuite, RBuiltinDiagFactory builtinFactory) {
             this.diagSuite = diagSuite;
             this.builtinFactory = builtinFactory;
-            this.builtinName = builtinFactory.getName();
+            this.builtinName = builtinFactory.getBuiltinName();
 
-            this.builtinClass = builtinFactory.getBuiltinNodeClass();
-            this.annotation = builtinClass.getAnnotation(RBuiltin.class);
-            this.argLength = annotation.parameterNames().length;
-            String[] pn = annotation.parameterNames();
+            String[] pn = builtinFactory.getParameterNames();
+            this.argLength = pn.length;
             this.parameterNames = Arrays.stream(pn).map(n -> n.isEmpty() ? null : n).toArray(String[]::new);
 
             this.castNodes = getCastNodesFromBuiltin();
@@ -166,7 +172,7 @@ public class RBuiltinDiagnostics {
                 return !((diagSuite.diagConfig.ignoreRNull && t == RNull.class) || (diagSuite.diagConfig.ignoreRMissing && t == RMissing.class));
             })).collect(Collectors.toList());
 
-            this.specMethods = CastUtils.getAnnotatedMethods(builtinClass, Specialization.class);
+            this.specMethods = CastUtils.getAnnotatedMethods(builtinFactory.getBuiltinNodeClass(), Specialization.class);
 
             this.convResultTypePerSpec = createConvResultTypePerSpecialization();
             this.nonCoveredArgsSet = combineArguments();
@@ -225,7 +231,7 @@ public class RBuiltinDiagnostics {
 
         protected void diagnosePipeline(int i) {
             TypeExpr argResultSet = argResultSets.get(i);
-            System.out.println("\n Pipeline for '" + annotation.parameterNames()[i] + "' (arg[" + i + "]):");
+            System.out.println("\n Pipeline for '" + parameterNames[i] + "' (arg[" + i + "]):");
             System.out.println("  Result types union:");
             Set<Type> argSetNorm = argResultSet.normalize();
             System.out.println("   " + argSetNorm.stream().map(argType -> typeName(argType)).collect(Collectors.toSet()));
@@ -247,17 +253,7 @@ public class RBuiltinDiagnostics {
         }
 
         private CastNode[] getCastNodesFromBuiltin() {
-            ArgumentsSignature signature = ArgumentsSignature.get(parameterNames);
-
-            int total = signature.getLength();
-            RNode[] args = new RNode[total];
-            for (int i = 0; i < total; i++) {
-                args[i] = ReadVariableNode.create("dummy");
-            }
-            RBuiltinNode builtinNode = builtinFactory.getConstructor().apply(args.clone());
-
-            CastNode[] cn = builtinNode.getCasts();
-            return cn;
+            return builtinFactory.getCasts();
         }
 
         private List<TypeExpr> createArgResultSets() {
@@ -320,4 +316,123 @@ public class RBuiltinDiagnostics {
         }
         return typeName(m.getReturnType()) + " " + m.getName() + "(" + sb + ")";
     }
+
+    public interface RBuiltinDiagFactory {
+        String getBuiltinName();
+
+        Class<?> getBuiltinNodeClass();
+
+        String[] getParameterNames();
+
+        CastNode[] getCasts();
+
+    }
+
+    public static final class RIntBuiltinDiagFactory implements RBuiltinDiagFactory {
+
+        private final RBuiltinFactory fact;
+
+        public RIntBuiltinDiagFactory(RBuiltinFactory fact) {
+            super();
+            this.fact = fact;
+        }
+
+        @Override
+        public String getBuiltinName() {
+            return fact.getName();
+        }
+
+        @Override
+        public Class<?> getBuiltinNodeClass() {
+            return fact.getBuiltinNodeClass();
+        }
+
+        public RBuiltinKind getBuiltinKind() {
+            return fact.getKind();
+        }
+
+        @Override
+        public String[] getParameterNames() {
+            RBuiltin annotation = fact.getBuiltinNodeClass().getAnnotation(RBuiltin.class);
+            String[] pn = annotation.parameterNames();
+            return Arrays.stream(pn).map(n -> n.isEmpty() ? null : n).toArray(String[]::new);
+        }
+
+        @Override
+        public CastNode[] getCasts() {
+            ArgumentsSignature signature = ArgumentsSignature.get(getParameterNames());
+
+            int total = signature.getLength();
+            RNode[] args = new RNode[total];
+            for (int i = 0; i < total; i++) {
+                args[i] = ReadVariableNode.create("dummy");
+            }
+
+            return fact.getConstructor().apply(args).getCasts();
+        }
+
+    }
+
+    public static final class RExtBuiltinDiagFactory implements RBuiltinDiagFactory {
+
+        private final Class<? extends RExternalBuiltinNode> nodeClass;
+        private final String[] parameterNames;
+
+        RExtBuiltinDiagFactory(Class<? extends RExternalBuiltinNode> nodeClass, int arity) {
+            this.nodeClass = nodeClass;
+            this.parameterNames = new String[arity];
+            for (int i = 0; i < arity; i++) {
+                this.parameterNames[i] = "arg" + i;
+            }
+        }
+
+        @SuppressWarnings("unchecked")
+        public static RExtBuiltinDiagFactory create(String extBuiltinClsName) throws ClassNotFoundException {
+            Class<?> nodeClass = Class.forName(extBuiltinClsName);
+
+            if (!Modifier.isFinal(nodeClass.getModifiers())) {
+                nodeClass = Class.forName(extBuiltinClsName + "NodeGen");
+                if (!Modifier.isFinal(nodeClass.getModifiers())) {
+                    throw new IllegalArgumentException("Invalid external builtin class name: " + extBuiltinClsName);
+                }
+            }
+
+            if (!RExternalBuiltinNode.class.isAssignableFrom(nodeClass)) {
+                throw new IllegalArgumentException(extBuiltinClsName + " is not a subclass of " + RExternalBuiltinNode.class.getName());
+            }
+
+            Optional<Method> execMethod = Arrays.stream(nodeClass.getMethods()).filter(
+                            m -> m.getName().equals("execute") && Arrays.stream(m.getParameterTypes()).allMatch(t -> t == Object.class)).findFirst();
+            if (execMethod.isPresent()) {
+                return new RExtBuiltinDiagFactory((Class<RExternalBuiltinNode>) nodeClass, execMethod.get().getParameterCount());
+            } else {
+                throw new UnsupportedOperationException(extBuiltinClsName + " is not a supported external builtin class");
+            }
+        }
+
+        @Override
+        public String getBuiltinName() {
+            return nodeClass.getSimpleName();
+        }
+
+        @Override
+        public Class<?> getBuiltinNodeClass() {
+            return nodeClass;
+        }
+
+        @Override
+        public String[] getParameterNames() {
+            return parameterNames;
+        }
+
+        @Override
+        public CastNode[] getCasts() {
+            try {
+                return ((RExternalBuiltinNode) nodeClass.getMethod("create").invoke(null)).getCasts();
+            } catch (Exception e) {
+                throw new RuntimeException(e);
+            }
+        }
+    }
+
 }
diff --git a/mx.fastr/mx_fastr.py b/mx.fastr/mx_fastr.py
index a32ba1f9d5..e329ea7149 100644
--- a/mx.fastr/mx_fastr.py
+++ b/mx.fastr/mx_fastr.py
@@ -449,6 +449,7 @@ def rbdiag(args):
     --sweep-total	Performs the 'chimney-sweeping'. The total sample selection method is used.
 
 	If no builtin is specified, all registered builtins are diagnosed.
+	An external builtin is specified by the fully qualified name of its node class.
 
 	Examples:
 
@@ -456,6 +457,7 @@ def rbdiag(args):
 		mx rbdiag colSums colMeans -v
 		mx rbdiag scan -m -n
     	mx rbdiag colSums --sweep
+    	mx rbdiag com.oracle.truffle.r.library.stats.Rnorm
     '''
     cp = mx.classpath('com.oracle.truffle.r.nodes.test')
 
-- 
GitLab