diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/ResultTypesAnalyserTest.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/ResultTypesAnalyserTest.java index 8e06ec87346509fc8cf0215c6e851c3ac1b49649..d22ab4cd4cb6915f75b410e1718b98bdd5769f33 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/ResultTypesAnalyserTest.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/ResultTypesAnalyserTest.java @@ -433,6 +433,12 @@ public class ResultTypesAnalyserTest { assertTypes(atom(String.class).lower(m("m2")).or(atom(Integer.class).lower(m("m1")))); } + @Test + public void testAllowMissing() { + arg.allowMissing().mustBe(stringValue()); + assertTypes(RMissing.class, String.class, RAbstractStringVector.class); + } + @Test public void testTwoWildcardTypes() { arg.mustBe((instanceOf(String.class).and(mark(length(10), "l10").or(mark(length(20), "l20"))))); diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/ResultTypesAnalyser.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/ResultTypesAnalyser.java index 8dcc52367a6926d1a90d926fd5638a370c338874..ba943022bf43323100d13f7007016a6021ecec78 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/ResultTypesAnalyser.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/ResultTypesAnalyser.java @@ -131,6 +131,8 @@ public class ResultTypesAnalyser extends ExecutionPathVisitor<AltTypeExpr> imple } + private static final TypeExpr NOT_NULL_NOT_MISSING = atom(RNull.class).not().and(atom(RMissing.class).not()); + public static TypeExpr analyse(PipelineStep<?, ?> firstStep) { return analyse(firstStep, AltTypeExpr.create()).merge(); } @@ -222,8 +224,7 @@ public class ResultTypesAnalyser extends ExecutionPathVisitor<AltTypeExpr> imple if (visitTrueBranch) { if (step.isReturns()) { AltTypeExpr returnedType = trueBranchResultTypes(step, inputType, filterRes); - inputType.addAlt(returnedType.merge()); - return inputType; + return inputType.addAlt(returnedType.merge()); } else { return trueBranchResultTypes(step, inputType, filterRes); } @@ -326,7 +327,7 @@ public class ResultTypesAnalyser extends ExecutionPathVisitor<AltTypeExpr> imple if (filter.getType2() != null) { resTp = resTp.or(atom(filter.getType2())); } - return resTp.and(atom(RNull.class).not().and(atom(RMissing.class).not())); + return resTp.and(NOT_NULL_NOT_MISSING); } @Override @@ -377,7 +378,7 @@ public class ResultTypesAnalyser extends ExecutionPathVisitor<AltTypeExpr> imple @Override public TypeExpr visit(MatrixFilter<?> filter, TypeExpr previous) { - return previous.lower(filter); + return previous.lower(filter).and(NOT_NULL_NOT_MISSING); } @Override @@ -437,17 +438,17 @@ public class ResultTypesAnalyser extends ExecutionPathVisitor<AltTypeExpr> imple @Override public TypeExpr visit(VectorSize vectorSize, byte operation, TypeExpr previous) { - return previous; + return previous.and(NOT_NULL_NOT_MISSING); } @Override public TypeExpr visit(ElementAt elementAt, byte operation, TypeExpr previous) { - return previous; + return previous.and(NOT_NULL_NOT_MISSING); } @Override public TypeExpr visit(Dim dim, byte operation, TypeExpr previous) { - return previous; + return previous.and(NOT_NULL_NOT_MISSING); } } diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/RBuiltinDiagnostics.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/RBuiltinDiagnostics.java index 3ec99e591e05e2bdfe9661f1c18d17e1ad01901e..c4f1c1b1027b5e5ecf2f0b8896461261a3c1959b 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/RBuiltinDiagnostics.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/test/RBuiltinDiagnostics.java @@ -35,8 +35,10 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Predicate; import java.util.stream.Collectors; +import com.oracle.truffle.api.dsl.GeneratedBy; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.frame.Frame; import com.oracle.truffle.r.nodes.builtin.NodeWithArgumentCasts; @@ -57,6 +59,7 @@ import com.oracle.truffle.r.runtime.RInternalError; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.builtins.RBuiltinKind; import com.oracle.truffle.r.runtime.data.RMissing; +import com.oracle.truffle.r.runtime.nodes.RBaseNode; public class RBuiltinDiagnostics { @@ -278,24 +281,29 @@ public class RBuiltinDiagnostics { return sb.toString(); } + private static String toGenNodeName(String name) { + if (name.endsWith("Node")) { + return name + "Gen"; + } else { + return name + "NodeGen"; + } + } + private static Class<?> toNodeGenClass(Class<?> nodeCls) throws ClassNotFoundException { String nodeGenClsName; if (nodeCls.getEnclosingClass() == null) { - nodeGenClsName = nodeCls.getName() + "NodeGen"; + nodeGenClsName = toGenNodeName(nodeCls.getName()); } else { + String enclClsName = nodeCls.getEnclosingClass().getName(); + String enclosingClsSuffix = RBaseNode.class.isAssignableFrom(nodeCls.getEnclosingClass()) ? (enclClsName.endsWith("Node") ? "Gen" : "NodeGen") : "Factory"; String[] split = nodeCls.getName().split("\\."); StringBuilder sb = new StringBuilder(); for (int i = 0; i < split.length; i++) { String s = split[i]; if (i == split.length - 1) { String[] lastSplit = s.split("\\$"); - sb.append(lastSplit[0] + "Factory$"); - sb.append(lastSplit[1]); - if (s.endsWith("Node")) { - sb.append("Gen"); - } else { - sb.append("NodeGen"); - } + sb.append(lastSplit[0] + enclosingClsSuffix + "$"); + sb.append(toGenNodeName(lastSplit[1])); } else { sb.append(s); } @@ -360,7 +368,15 @@ public class RBuiltinDiagnostics { argResultSets = createArgResultSets(); - this.specMethods = CastUtils.getAnnotatedMethods(builtinFactory.getBuiltinNodeClass(), Specialization.class); + List<Method> specs = CastUtils.getAnnotatedMethods(builtinFactory.getBuiltinNodeClass(), Specialization.class); + this.specMethods = new ArrayList<>(specs); + // N.B. The fallback method cannot be found by the Fallback annotation since + // this annotation has the CLASS retention policy. Nonetheless, the fallback method can + // be determined throught the fallback node in the generated class. + Optional<Method> fallback = findFallbackMethod(toNodeGenClass(bltnCls)); + if (fallback.isPresent()) { + this.specMethods.add(fallback.get()); + } this.convResultTypePerSpec = createConvResultTypePerSpecialization(); this.nonCoveredArgsSet = combineArguments(); @@ -581,6 +597,34 @@ public class RBuiltinDiagnostics { return typeName(m.getReturnType()) + " " + m.getName() + "(" + sb + ")"; } + private static Optional<Method> findFallbackMethod(Class<?> genBltnClass) { + Optional<Class<?>> fallbackNodeCls = Arrays.stream(genBltnClass.getDeclaredClasses()).filter(c -> "FallbackNode_".equals(c.getSimpleName())).findFirst(); + return fallbackNodeCls.flatMap(fc -> findFallbackMethodFromAnnot(fc, genBltnClass.getSuperclass())); + } + + private static Optional<Method> findFallbackMethodFromAnnot(Class<?> fallbackNodeClass, Class<?> bltnCls) { + GeneratedBy genByAnnot = fallbackNodeClass.getAnnotation(GeneratedBy.class); + assert genByAnnot != null; + String fallbackName = genByAnnot.methodName(); + + return findMethod(bltnCls, dm -> { + return dm.getAnnotation(Specialization.class) == null && fallbackName.startsWith(dm.getName() + "("); + }); + } + + private static Optional<Method> findMethod(Class<?> clazz, Predicate<Method> filter) { + Optional<Method> res = Arrays.asList(clazz.getDeclaredMethods()).stream().filter(filter).findFirst(); + if (res.isPresent()) { + return res; + } + + if (clazz.getSuperclass() != Object.class) { + return findMethod(clazz.getSuperclass(), filter); + } else { + return Optional.empty(); + } + } + public interface RBuiltinDiagFactory { String getBuiltinName();