From 098c1137b06ebe26b37f69888b0583e4224e121d Mon Sep 17 00:00:00 2001 From: Zbynek Slajchrt <zbynek.slajchrt@oracle.com> Date: Fri, 9 Sep 2016 11:50:25 +0200 Subject: [PATCH] rbdiag supports external builtins (no chimney-sweeping atm) --- .../truffle/r/nodes/test/ChimneySweeping.java | 20 +- .../r/nodes/test/RBuiltinDiagnostics.java | 171 +++++++++++++++--- mx.fastr/mx_fastr.py | 2 + 3 files changed, 160 insertions(+), 33 deletions(-) diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/ChimneySweeping.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/ChimneySweeping.java index 3c1dd496b8..57c5f03a4f 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/ChimneySweeping.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/ChimneySweeping.java @@ -44,9 +44,12 @@ import com.oracle.truffle.api.dsl.UnsupportedSpecializationException; import com.oracle.truffle.api.vm.PolyglotEngine; import com.oracle.truffle.api.vm.PolyglotEngine.Value; import com.oracle.truffle.r.nodes.builtin.RBuiltinFactory; +import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.casts.CastNodeSampler; import com.oracle.truffle.r.nodes.casts.Samples; import com.oracle.truffle.r.nodes.test.RBuiltinDiagnostics.DiagConfig; +import com.oracle.truffle.r.nodes.test.RBuiltinDiagnostics.RBuiltinDiagFactory; +import com.oracle.truffle.r.nodes.test.RBuiltinDiagnostics.RIntBuiltinDiagFactory; import com.oracle.truffle.r.nodes.test.RBuiltinDiagnostics.SingleBuiltinDiagnostics; import com.oracle.truffle.r.nodes.test.TestUtilities.NodeHandle; import com.oracle.truffle.r.nodes.unary.CastNode; @@ -54,6 +57,7 @@ import com.oracle.truffle.r.runtime.RDeparse; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RSource; import com.oracle.truffle.r.runtime.ResourceHandlerFactory; +import com.oracle.truffle.r.runtime.builtins.RBuiltinKind; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RNull; @@ -129,8 +133,12 @@ class ChimneySweeping extends SingleBuiltinDiagnostics { } @Override - public SingleBuiltinDiagnostics createBuiltinDiagnostics(RBuiltinFactory bf) { - return new ChimneySweeping(this, bf); + public SingleBuiltinDiagnostics createBuiltinDiagnostics(RBuiltinDiagFactory bf) { + if (bf instanceof RIntBuiltinDiagFactory) { + return new ChimneySweeping(this, (RIntBuiltinDiagFactory) bf); + } else { + throw new UnsupportedOperationException("Only non-external builtins supported for chimney-sweeping atm"); + } } private static TestOutputManager loadTestOutputManager() throws IOException { @@ -151,16 +159,18 @@ class ChimneySweeping extends SingleBuiltinDiagnostics { private final List<Samples<?>> argSamples; private final ChimneySweepingSuite diagSuite; private final Set<RList> validArgsList; + private final RBuiltinKind kind; private final Set<List<String>> printedOutputPairs = new HashSet<>(); private final Set<String> printedErrors = new HashSet<>(); private int sweepCounter = 0; - ChimneySweeping(ChimneySweepingSuite diagSuite, RBuiltinFactory builtinFactory) { + ChimneySweeping(ChimneySweepingSuite diagSuite, RIntBuiltinDiagFactory builtinFactory) { super(diagSuite, builtinFactory); this.diagSuite = diagSuite; this.validArgsList = extractValidArgsForBuiltin(); this.argSamples = createSamples(); + this.kind = builtinFactory.getBuiltinKind(); } @Override @@ -274,7 +284,7 @@ class ChimneySweeping extends SingleBuiltinDiagnostics { try { String snippetAnchor; - switch (annotation.kind()) { + switch (kind) { case INTERNAL: snippetAnchor = ".Internal(" + builtinName + "("; break; @@ -385,7 +395,7 @@ class ChimneySweeping extends SingleBuiltinDiagnostics { } String call; - switch (annotation.kind()) { + switch (kind) { case INTERNAL: call = ".Internal(" + builtinName + "(" + sb + "))"; break; diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/RBuiltinDiagnostics.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/RBuiltinDiagnostics.java index 1e6844fbbc..e9079da634 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/RBuiltinDiagnostics.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/RBuiltinDiagnostics.java @@ -23,6 +23,7 @@ package com.oracle.truffle.r.nodes.test; import java.lang.reflect.Method; +import java.lang.reflect.Modifier; import java.lang.reflect.Type; import java.util.ArrayList; import java.util.Arrays; @@ -30,6 +31,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -38,7 +40,7 @@ import com.oracle.truffle.api.frame.Frame; import com.oracle.truffle.r.nodes.access.variables.ReadVariableNode; import com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef; import com.oracle.truffle.r.nodes.builtin.RBuiltinFactory; -import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode; import com.oracle.truffle.r.nodes.builtin.base.BasePackage; import com.oracle.truffle.r.nodes.casts.CastNodeSampler; import com.oracle.truffle.r.nodes.casts.CastUtils; @@ -52,6 +54,7 @@ import com.oracle.truffle.r.nodes.test.ChimneySweeping.ChimneySweepingSuite; import com.oracle.truffle.r.nodes.unary.CastNode; import com.oracle.truffle.r.runtime.ArgumentsSignature; import com.oracle.truffle.r.runtime.builtins.RBuiltin; +import com.oracle.truffle.r.runtime.builtins.RBuiltinKind; import com.oracle.truffle.r.runtime.data.RMissing; import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.nodes.RNode; @@ -98,19 +101,26 @@ public class RBuiltinDiagnostics { } } - public SingleBuiltinDiagnostics createBuiltinDiagnostics(RBuiltinFactory bf) { + public SingleBuiltinDiagnostics createBuiltinDiagnostics(RBuiltinDiagFactory bf) { return new SingleBuiltinDiagnostics(this, bf); } public void diagnoseSingleBuiltin(String builtinName) throws Exception { BasePackage bp = new BasePackage(); RBuiltinFactory bf = bp.lookupByName(builtinName); + RBuiltinDiagFactory bdf; if (bf == null) { - System.out.println("No builtin '" + builtinName + "' found"); - return; + try { + bdf = RExtBuiltinDiagFactory.create(builtinName); + } catch (Exception e) { + System.out.println("No builtin '" + builtinName + "' found"); + return; + } + } else { + bdf = new RIntBuiltinDiagFactory(bf); } - createBuiltinDiagnostics(bf).diagnoseBuiltin(); + createBuiltinDiagnostics(bdf).diagnoseBuiltin(); System.out.println("Finished"); System.out.println("--------"); @@ -122,7 +132,7 @@ public class RBuiltinDiagnostics { BasePackage bp = new BasePackage(); for (RBuiltinFactory bf : bp.getBuiltins().values()) { try { - createBuiltinDiagnostics(bf).diagnoseBuiltin(); + createBuiltinDiagnostics(new RIntBuiltinDiagFactory((bf))).diagnoseBuiltin(); } catch (Exception e) { System.out.println(bf.getName() + " failed: " + e.getMessage()); } @@ -136,27 +146,23 @@ public class RBuiltinDiagnostics { static class SingleBuiltinDiagnostics { private final RBuiltinDiagnostics diagSuite; - final RBuiltinFactory builtinFactory; + final RBuiltinDiagFactory builtinFactory; final String builtinName; final int argLength; final String[] parameterNames; final CastNode[] castNodes; - final Class<?> builtinClass; - final RBuiltin annotation; final List<Method> specMethods; final List<TypeExpr> argResultSets; final HashMap<Method, List<Set<Cast>>> convResultTypePerSpec; final Set<List<Type>> nonCoveredArgsSet; - SingleBuiltinDiagnostics(RBuiltinDiagnostics diagSuite, RBuiltinFactory builtinFactory) { + SingleBuiltinDiagnostics(RBuiltinDiagnostics diagSuite, RBuiltinDiagFactory builtinFactory) { this.diagSuite = diagSuite; this.builtinFactory = builtinFactory; - this.builtinName = builtinFactory.getName(); + this.builtinName = builtinFactory.getBuiltinName(); - this.builtinClass = builtinFactory.getBuiltinNodeClass(); - this.annotation = builtinClass.getAnnotation(RBuiltin.class); - this.argLength = annotation.parameterNames().length; - String[] pn = annotation.parameterNames(); + String[] pn = builtinFactory.getParameterNames(); + this.argLength = pn.length; this.parameterNames = Arrays.stream(pn).map(n -> n.isEmpty() ? null : n).toArray(String[]::new); this.castNodes = getCastNodesFromBuiltin(); @@ -166,7 +172,7 @@ public class RBuiltinDiagnostics { return !((diagSuite.diagConfig.ignoreRNull && t == RNull.class) || (diagSuite.diagConfig.ignoreRMissing && t == RMissing.class)); })).collect(Collectors.toList()); - this.specMethods = CastUtils.getAnnotatedMethods(builtinClass, Specialization.class); + this.specMethods = CastUtils.getAnnotatedMethods(builtinFactory.getBuiltinNodeClass(), Specialization.class); this.convResultTypePerSpec = createConvResultTypePerSpecialization(); this.nonCoveredArgsSet = combineArguments(); @@ -225,7 +231,7 @@ public class RBuiltinDiagnostics { protected void diagnosePipeline(int i) { TypeExpr argResultSet = argResultSets.get(i); - System.out.println("\n Pipeline for '" + annotation.parameterNames()[i] + "' (arg[" + i + "]):"); + System.out.println("\n Pipeline for '" + parameterNames[i] + "' (arg[" + i + "]):"); System.out.println(" Result types union:"); Set<Type> argSetNorm = argResultSet.normalize(); System.out.println(" " + argSetNorm.stream().map(argType -> typeName(argType)).collect(Collectors.toSet())); @@ -247,17 +253,7 @@ public class RBuiltinDiagnostics { } private CastNode[] getCastNodesFromBuiltin() { - ArgumentsSignature signature = ArgumentsSignature.get(parameterNames); - - int total = signature.getLength(); - RNode[] args = new RNode[total]; - for (int i = 0; i < total; i++) { - args[i] = ReadVariableNode.create("dummy"); - } - RBuiltinNode builtinNode = builtinFactory.getConstructor().apply(args.clone()); - - CastNode[] cn = builtinNode.getCasts(); - return cn; + return builtinFactory.getCasts(); } private List<TypeExpr> createArgResultSets() { @@ -320,4 +316,123 @@ public class RBuiltinDiagnostics { } return typeName(m.getReturnType()) + " " + m.getName() + "(" + sb + ")"; } + + public interface RBuiltinDiagFactory { + String getBuiltinName(); + + Class<?> getBuiltinNodeClass(); + + String[] getParameterNames(); + + CastNode[] getCasts(); + + } + + public static final class RIntBuiltinDiagFactory implements RBuiltinDiagFactory { + + private final RBuiltinFactory fact; + + public RIntBuiltinDiagFactory(RBuiltinFactory fact) { + super(); + this.fact = fact; + } + + @Override + public String getBuiltinName() { + return fact.getName(); + } + + @Override + public Class<?> getBuiltinNodeClass() { + return fact.getBuiltinNodeClass(); + } + + public RBuiltinKind getBuiltinKind() { + return fact.getKind(); + } + + @Override + public String[] getParameterNames() { + RBuiltin annotation = fact.getBuiltinNodeClass().getAnnotation(RBuiltin.class); + String[] pn = annotation.parameterNames(); + return Arrays.stream(pn).map(n -> n.isEmpty() ? null : n).toArray(String[]::new); + } + + @Override + public CastNode[] getCasts() { + ArgumentsSignature signature = ArgumentsSignature.get(getParameterNames()); + + int total = signature.getLength(); + RNode[] args = new RNode[total]; + for (int i = 0; i < total; i++) { + args[i] = ReadVariableNode.create("dummy"); + } + + return fact.getConstructor().apply(args).getCasts(); + } + + } + + public static final class RExtBuiltinDiagFactory implements RBuiltinDiagFactory { + + private final Class<? extends RExternalBuiltinNode> nodeClass; + private final String[] parameterNames; + + RExtBuiltinDiagFactory(Class<? extends RExternalBuiltinNode> nodeClass, int arity) { + this.nodeClass = nodeClass; + this.parameterNames = new String[arity]; + for (int i = 0; i < arity; i++) { + this.parameterNames[i] = "arg" + i; + } + } + + @SuppressWarnings("unchecked") + public static RExtBuiltinDiagFactory create(String extBuiltinClsName) throws ClassNotFoundException { + Class<?> nodeClass = Class.forName(extBuiltinClsName); + + if (!Modifier.isFinal(nodeClass.getModifiers())) { + nodeClass = Class.forName(extBuiltinClsName + "NodeGen"); + if (!Modifier.isFinal(nodeClass.getModifiers())) { + throw new IllegalArgumentException("Invalid external builtin class name: " + extBuiltinClsName); + } + } + + if (!RExternalBuiltinNode.class.isAssignableFrom(nodeClass)) { + throw new IllegalArgumentException(extBuiltinClsName + " is not a subclass of " + RExternalBuiltinNode.class.getName()); + } + + Optional<Method> execMethod = Arrays.stream(nodeClass.getMethods()).filter( + m -> m.getName().equals("execute") && Arrays.stream(m.getParameterTypes()).allMatch(t -> t == Object.class)).findFirst(); + if (execMethod.isPresent()) { + return new RExtBuiltinDiagFactory((Class<RExternalBuiltinNode>) nodeClass, execMethod.get().getParameterCount()); + } else { + throw new UnsupportedOperationException(extBuiltinClsName + " is not a supported external builtin class"); + } + } + + @Override + public String getBuiltinName() { + return nodeClass.getSimpleName(); + } + + @Override + public Class<?> getBuiltinNodeClass() { + return nodeClass; + } + + @Override + public String[] getParameterNames() { + return parameterNames; + } + + @Override + public CastNode[] getCasts() { + try { + return ((RExternalBuiltinNode) nodeClass.getMethod("create").invoke(null)).getCasts(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + } diff --git a/mx.fastr/mx_fastr.py b/mx.fastr/mx_fastr.py index a32ba1f9d5..e329ea7149 100644 --- a/mx.fastr/mx_fastr.py +++ b/mx.fastr/mx_fastr.py @@ -449,6 +449,7 @@ def rbdiag(args): --sweep-total Performs the 'chimney-sweeping'. The total sample selection method is used. If no builtin is specified, all registered builtins are diagnosed. + An external builtin is specified by the fully qualified name of its node class. Examples: @@ -456,6 +457,7 @@ def rbdiag(args): mx rbdiag colSums colMeans -v mx rbdiag scan -m -n mx rbdiag colSums --sweep + mx rbdiag com.oracle.truffle.r.library.stats.Rnorm ''' cp = mx.classpath('com.oracle.truffle.r.nodes.test') -- GitLab