Skip to content
Snippets Groups Projects
Commit 098c1137 authored by Zbynek Slajchrt's avatar Zbynek Slajchrt
Browse files

rbdiag supports external builtins (no chimney-sweeping atm)

parent 1b5268e8
No related branches found
No related tags found
No related merge requests found
......@@ -44,9 +44,12 @@ import com.oracle.truffle.api.dsl.UnsupportedSpecializationException;
import com.oracle.truffle.api.vm.PolyglotEngine;
import com.oracle.truffle.api.vm.PolyglotEngine.Value;
import com.oracle.truffle.r.nodes.builtin.RBuiltinFactory;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.casts.CastNodeSampler;
import com.oracle.truffle.r.nodes.casts.Samples;
import com.oracle.truffle.r.nodes.test.RBuiltinDiagnostics.DiagConfig;
import com.oracle.truffle.r.nodes.test.RBuiltinDiagnostics.RBuiltinDiagFactory;
import com.oracle.truffle.r.nodes.test.RBuiltinDiagnostics.RIntBuiltinDiagFactory;
import com.oracle.truffle.r.nodes.test.RBuiltinDiagnostics.SingleBuiltinDiagnostics;
import com.oracle.truffle.r.nodes.test.TestUtilities.NodeHandle;
import com.oracle.truffle.r.nodes.unary.CastNode;
......@@ -54,6 +57,7 @@ import com.oracle.truffle.r.runtime.RDeparse;
import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RSource;
import com.oracle.truffle.r.runtime.ResourceHandlerFactory;
import com.oracle.truffle.r.runtime.builtins.RBuiltinKind;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RList;
import com.oracle.truffle.r.runtime.data.RNull;
......@@ -129,8 +133,12 @@ class ChimneySweeping extends SingleBuiltinDiagnostics {
}
@Override
public SingleBuiltinDiagnostics createBuiltinDiagnostics(RBuiltinFactory bf) {
return new ChimneySweeping(this, bf);
public SingleBuiltinDiagnostics createBuiltinDiagnostics(RBuiltinDiagFactory bf) {
if (bf instanceof RIntBuiltinDiagFactory) {
return new ChimneySweeping(this, (RIntBuiltinDiagFactory) bf);
} else {
throw new UnsupportedOperationException("Only non-external builtins supported for chimney-sweeping atm");
}
}
private static TestOutputManager loadTestOutputManager() throws IOException {
......@@ -151,16 +159,18 @@ class ChimneySweeping extends SingleBuiltinDiagnostics {
private final List<Samples<?>> argSamples;
private final ChimneySweepingSuite diagSuite;
private final Set<RList> validArgsList;
private final RBuiltinKind kind;
private final Set<List<String>> printedOutputPairs = new HashSet<>();
private final Set<String> printedErrors = new HashSet<>();
private int sweepCounter = 0;
ChimneySweeping(ChimneySweepingSuite diagSuite, RBuiltinFactory builtinFactory) {
ChimneySweeping(ChimneySweepingSuite diagSuite, RIntBuiltinDiagFactory builtinFactory) {
super(diagSuite, builtinFactory);
this.diagSuite = diagSuite;
this.validArgsList = extractValidArgsForBuiltin();
this.argSamples = createSamples();
this.kind = builtinFactory.getBuiltinKind();
}
@Override
......@@ -274,7 +284,7 @@ class ChimneySweeping extends SingleBuiltinDiagnostics {
try {
String snippetAnchor;
switch (annotation.kind()) {
switch (kind) {
case INTERNAL:
snippetAnchor = ".Internal(" + builtinName + "(";
break;
......@@ -385,7 +395,7 @@ class ChimneySweeping extends SingleBuiltinDiagnostics {
}
String call;
switch (annotation.kind()) {
switch (kind) {
case INTERNAL:
call = ".Internal(" + builtinName + "(" + sb + "))";
break;
......
......@@ -23,6 +23,7 @@
package com.oracle.truffle.r.nodes.test;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Arrays;
......@@ -30,6 +31,7 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
......@@ -38,7 +40,7 @@ import com.oracle.truffle.api.frame.Frame;
import com.oracle.truffle.r.nodes.access.variables.ReadVariableNode;
import com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef;
import com.oracle.truffle.r.nodes.builtin.RBuiltinFactory;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode;
import com.oracle.truffle.r.nodes.builtin.base.BasePackage;
import com.oracle.truffle.r.nodes.casts.CastNodeSampler;
import com.oracle.truffle.r.nodes.casts.CastUtils;
......@@ -52,6 +54,7 @@ import com.oracle.truffle.r.nodes.test.ChimneySweeping.ChimneySweepingSuite;
import com.oracle.truffle.r.nodes.unary.CastNode;
import com.oracle.truffle.r.runtime.ArgumentsSignature;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.builtins.RBuiltinKind;
import com.oracle.truffle.r.runtime.data.RMissing;
import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.nodes.RNode;
......@@ -98,19 +101,26 @@ public class RBuiltinDiagnostics {
}
}
public SingleBuiltinDiagnostics createBuiltinDiagnostics(RBuiltinFactory bf) {
public SingleBuiltinDiagnostics createBuiltinDiagnostics(RBuiltinDiagFactory bf) {
return new SingleBuiltinDiagnostics(this, bf);
}
public void diagnoseSingleBuiltin(String builtinName) throws Exception {
BasePackage bp = new BasePackage();
RBuiltinFactory bf = bp.lookupByName(builtinName);
RBuiltinDiagFactory bdf;
if (bf == null) {
System.out.println("No builtin '" + builtinName + "' found");
return;
try {
bdf = RExtBuiltinDiagFactory.create(builtinName);
} catch (Exception e) {
System.out.println("No builtin '" + builtinName + "' found");
return;
}
} else {
bdf = new RIntBuiltinDiagFactory(bf);
}
createBuiltinDiagnostics(bf).diagnoseBuiltin();
createBuiltinDiagnostics(bdf).diagnoseBuiltin();
System.out.println("Finished");
System.out.println("--------");
......@@ -122,7 +132,7 @@ public class RBuiltinDiagnostics {
BasePackage bp = new BasePackage();
for (RBuiltinFactory bf : bp.getBuiltins().values()) {
try {
createBuiltinDiagnostics(bf).diagnoseBuiltin();
createBuiltinDiagnostics(new RIntBuiltinDiagFactory((bf))).diagnoseBuiltin();
} catch (Exception e) {
System.out.println(bf.getName() + " failed: " + e.getMessage());
}
......@@ -136,27 +146,23 @@ public class RBuiltinDiagnostics {
static class SingleBuiltinDiagnostics {
private final RBuiltinDiagnostics diagSuite;
final RBuiltinFactory builtinFactory;
final RBuiltinDiagFactory builtinFactory;
final String builtinName;
final int argLength;
final String[] parameterNames;
final CastNode[] castNodes;
final Class<?> builtinClass;
final RBuiltin annotation;
final List<Method> specMethods;
final List<TypeExpr> argResultSets;
final HashMap<Method, List<Set<Cast>>> convResultTypePerSpec;
final Set<List<Type>> nonCoveredArgsSet;
SingleBuiltinDiagnostics(RBuiltinDiagnostics diagSuite, RBuiltinFactory builtinFactory) {
SingleBuiltinDiagnostics(RBuiltinDiagnostics diagSuite, RBuiltinDiagFactory builtinFactory) {
this.diagSuite = diagSuite;
this.builtinFactory = builtinFactory;
this.builtinName = builtinFactory.getName();
this.builtinName = builtinFactory.getBuiltinName();
this.builtinClass = builtinFactory.getBuiltinNodeClass();
this.annotation = builtinClass.getAnnotation(RBuiltin.class);
this.argLength = annotation.parameterNames().length;
String[] pn = annotation.parameterNames();
String[] pn = builtinFactory.getParameterNames();
this.argLength = pn.length;
this.parameterNames = Arrays.stream(pn).map(n -> n.isEmpty() ? null : n).toArray(String[]::new);
this.castNodes = getCastNodesFromBuiltin();
......@@ -166,7 +172,7 @@ public class RBuiltinDiagnostics {
return !((diagSuite.diagConfig.ignoreRNull && t == RNull.class) || (diagSuite.diagConfig.ignoreRMissing && t == RMissing.class));
})).collect(Collectors.toList());
this.specMethods = CastUtils.getAnnotatedMethods(builtinClass, Specialization.class);
this.specMethods = CastUtils.getAnnotatedMethods(builtinFactory.getBuiltinNodeClass(), Specialization.class);
this.convResultTypePerSpec = createConvResultTypePerSpecialization();
this.nonCoveredArgsSet = combineArguments();
......@@ -225,7 +231,7 @@ public class RBuiltinDiagnostics {
protected void diagnosePipeline(int i) {
TypeExpr argResultSet = argResultSets.get(i);
System.out.println("\n Pipeline for '" + annotation.parameterNames()[i] + "' (arg[" + i + "]):");
System.out.println("\n Pipeline for '" + parameterNames[i] + "' (arg[" + i + "]):");
System.out.println(" Result types union:");
Set<Type> argSetNorm = argResultSet.normalize();
System.out.println(" " + argSetNorm.stream().map(argType -> typeName(argType)).collect(Collectors.toSet()));
......@@ -247,17 +253,7 @@ public class RBuiltinDiagnostics {
}
private CastNode[] getCastNodesFromBuiltin() {
ArgumentsSignature signature = ArgumentsSignature.get(parameterNames);
int total = signature.getLength();
RNode[] args = new RNode[total];
for (int i = 0; i < total; i++) {
args[i] = ReadVariableNode.create("dummy");
}
RBuiltinNode builtinNode = builtinFactory.getConstructor().apply(args.clone());
CastNode[] cn = builtinNode.getCasts();
return cn;
return builtinFactory.getCasts();
}
private List<TypeExpr> createArgResultSets() {
......@@ -320,4 +316,123 @@ public class RBuiltinDiagnostics {
}
return typeName(m.getReturnType()) + " " + m.getName() + "(" + sb + ")";
}
public interface RBuiltinDiagFactory {
String getBuiltinName();
Class<?> getBuiltinNodeClass();
String[] getParameterNames();
CastNode[] getCasts();
}
public static final class RIntBuiltinDiagFactory implements RBuiltinDiagFactory {
private final RBuiltinFactory fact;
public RIntBuiltinDiagFactory(RBuiltinFactory fact) {
super();
this.fact = fact;
}
@Override
public String getBuiltinName() {
return fact.getName();
}
@Override
public Class<?> getBuiltinNodeClass() {
return fact.getBuiltinNodeClass();
}
public RBuiltinKind getBuiltinKind() {
return fact.getKind();
}
@Override
public String[] getParameterNames() {
RBuiltin annotation = fact.getBuiltinNodeClass().getAnnotation(RBuiltin.class);
String[] pn = annotation.parameterNames();
return Arrays.stream(pn).map(n -> n.isEmpty() ? null : n).toArray(String[]::new);
}
@Override
public CastNode[] getCasts() {
ArgumentsSignature signature = ArgumentsSignature.get(getParameterNames());
int total = signature.getLength();
RNode[] args = new RNode[total];
for (int i = 0; i < total; i++) {
args[i] = ReadVariableNode.create("dummy");
}
return fact.getConstructor().apply(args).getCasts();
}
}
public static final class RExtBuiltinDiagFactory implements RBuiltinDiagFactory {
private final Class<? extends RExternalBuiltinNode> nodeClass;
private final String[] parameterNames;
RExtBuiltinDiagFactory(Class<? extends RExternalBuiltinNode> nodeClass, int arity) {
this.nodeClass = nodeClass;
this.parameterNames = new String[arity];
for (int i = 0; i < arity; i++) {
this.parameterNames[i] = "arg" + i;
}
}
@SuppressWarnings("unchecked")
public static RExtBuiltinDiagFactory create(String extBuiltinClsName) throws ClassNotFoundException {
Class<?> nodeClass = Class.forName(extBuiltinClsName);
if (!Modifier.isFinal(nodeClass.getModifiers())) {
nodeClass = Class.forName(extBuiltinClsName + "NodeGen");
if (!Modifier.isFinal(nodeClass.getModifiers())) {
throw new IllegalArgumentException("Invalid external builtin class name: " + extBuiltinClsName);
}
}
if (!RExternalBuiltinNode.class.isAssignableFrom(nodeClass)) {
throw new IllegalArgumentException(extBuiltinClsName + " is not a subclass of " + RExternalBuiltinNode.class.getName());
}
Optional<Method> execMethod = Arrays.stream(nodeClass.getMethods()).filter(
m -> m.getName().equals("execute") && Arrays.stream(m.getParameterTypes()).allMatch(t -> t == Object.class)).findFirst();
if (execMethod.isPresent()) {
return new RExtBuiltinDiagFactory((Class<RExternalBuiltinNode>) nodeClass, execMethod.get().getParameterCount());
} else {
throw new UnsupportedOperationException(extBuiltinClsName + " is not a supported external builtin class");
}
}
@Override
public String getBuiltinName() {
return nodeClass.getSimpleName();
}
@Override
public Class<?> getBuiltinNodeClass() {
return nodeClass;
}
@Override
public String[] getParameterNames() {
return parameterNames;
}
@Override
public CastNode[] getCasts() {
try {
return ((RExternalBuiltinNode) nodeClass.getMethod("create").invoke(null)).getCasts();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
}
......@@ -449,6 +449,7 @@ def rbdiag(args):
--sweep-total Performs the 'chimney-sweeping'. The total sample selection method is used.
If no builtin is specified, all registered builtins are diagnosed.
An external builtin is specified by the fully qualified name of its node class.
Examples:
......@@ -456,6 +457,7 @@ def rbdiag(args):
mx rbdiag colSums colMeans -v
mx rbdiag scan -m -n
mx rbdiag colSums --sweep
mx rbdiag com.oracle.truffle.r.library.stats.Rnorm
'''
cp = mx.classpath('com.oracle.truffle.r.nodes.test')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment