From b48ec8a3890b5432943cad6bdf0a2e7a5034c51f Mon Sep 17 00:00:00 2001
From: Zbynek Slajchrt <zbynek.slajchrt@oracle.com>
Date: Thu, 23 Feb 2017 19:46:33 +0100
Subject: [PATCH] A couple of fixes in builtin diagnostics

---
 .../builtin/ResultTypesAnalyserTest.java      |  21 ++-
 .../r/nodes/casts/ResultTypesAnalyser.java    | 132 ++++++++----------
 .../builtin/casts/ExecutionPathVisitor.java   |  32 ++++-
 3 files changed, 102 insertions(+), 83 deletions(-)

diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/ResultTypesAnalyserTest.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/ResultTypesAnalyserTest.java
index d22ab4cd4c..3c2e35c7e5 100644
--- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/ResultTypesAnalyserTest.java
+++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/ResultTypesAnalyserTest.java
@@ -22,6 +22,8 @@
  */
 package com.oracle.truffle.r.nodes.builtin;
 
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.emptyIntegerVector;
+import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullConstant;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.atomicIntegerValue;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.atomicLogicalValue;
 import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.chain;
@@ -167,9 +169,8 @@ public class ResultTypesAnalyserTest {
     @Test
     public void testBoxPrimitive() {
         arg.boxPrimitive();
-        TypeExpr expected = TypeExpr.union(RNull.class, RMissing.class, RInteger.class, RLogical.class,
-                        RDouble.class, RString.class);
-        expected = expected.or(expected.not());
+        TypeExpr expected = TypeExpr.union(RInteger.class, RLogical.class, RDouble.class, RString.class);
+        expected = expected.or((expected.not().and(atom(String.class).not()).and(atom(Double.class).not()).and(atom(Integer.class).not()).and(atom(Byte.class).not())));
         assertTypes(expected);
     }
 
@@ -428,11 +429,23 @@ public class ResultTypesAnalyserTest {
     }
 
     @Test
-    public void testReturnIf() {
+    public void testReturnIf1() {
         arg.mapIf(nullValue(), mark(constant(1), "m1"), mark(constant("abc"), "m2"));
         assertTypes(atom(String.class).lower(m("m2")).or(atom(Integer.class).lower(m("m1"))));
     }
 
+    @Test
+    public void testReturnIf2() {
+        arg.returnIf(nullValue(), emptyIntegerVector()).returnIf(missingValue(), emptyIntegerVector()).asIntegerVector();
+        assertTypes(atom(int.class).or(atom(RIntSequence.class)).or(atom(RIntVector.class)), true);
+    }
+
+    @Test
+    public void testMustNotBeMissingAndBoxPrimitive() {
+        arg.mustNotBeMissing().returnIf(nullValue(), nullConstant()).mustBe(stringValue()).boxPrimitive().asStringVector();
+        assertTypes(atom(RNull.class).or(atom(RStringVector.class)), true);
+    }
+
     @Test
     public void testAllowMissing() {
         arg.allowMissing().mustBe(stringValue());
diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/ResultTypesAnalyser.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/ResultTypesAnalyser.java
index ba943022bf..5df10cc4e0 100644
--- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/ResultTypesAnalyser.java
+++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/ResultTypesAnalyser.java
@@ -65,7 +65,6 @@ import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.FindFirstStep;
 import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.MapIfStep;
 import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.MapStep;
 import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.NotNAStep;
-import com.oracle.truffle.r.nodes.casts.ResultTypesAnalyser.AltTypeExpr;
 import com.oracle.truffle.r.nodes.unary.CastComplexNode;
 import com.oracle.truffle.r.nodes.unary.CastDoubleBaseNode;
 import com.oracle.truffle.r.nodes.unary.CastDoubleNode;
@@ -98,58 +97,27 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
 
-public class ResultTypesAnalyser extends ExecutionPathVisitor<AltTypeExpr> implements MapperVisitor<TypeExpr>, FilterVisitor<TypeExpr>, SubjectVisitor<TypeExpr> {
-
-    static final class AltTypeExpr {
-        final TypeExpr main;
-        final TypeExpr alt;
-
-        private AltTypeExpr(TypeExpr mainBranch, TypeExpr altBranch) {
-            this.main = mainBranch;
-            this.alt = altBranch;
-        }
-
-        static AltTypeExpr create() {
-            return new AltTypeExpr(TypeExpr.ANYTHING, null);
-        }
-
-        TypeExpr merge() {
-            return alt == null ? main : main.or(alt);
-        }
-
-        AltTypeExpr setMain(TypeExpr newMainType) {
-            return newMainType == null ? this : new AltTypeExpr(newMainType, alt);
-        }
-
-        AltTypeExpr addAlt(TypeExpr newAltType) {
-            return newAltType == null ? this : new AltTypeExpr(main, alt == null ? newAltType : alt.or(newAltType));
-        }
-
-        AltTypeExpr or(AltTypeExpr other) {
-            return setMain(main.or(other.main)).addAlt(other.alt);
-        }
-
-    }
+public class ResultTypesAnalyser extends ExecutionPathVisitor<TypeExpr> implements MapperVisitor<TypeExpr>, FilterVisitor<TypeExpr>, SubjectVisitor<TypeExpr> {
 
     private static final TypeExpr NOT_NULL_NOT_MISSING = atom(RNull.class).not().and(atom(RMissing.class).not());
 
     public static TypeExpr analyse(PipelineStep<?, ?> firstStep) {
-        return analyse(firstStep, AltTypeExpr.create()).merge();
+        return analyse(firstStep, TypeExpr.ANYTHING);
     }
 
-    public static AltTypeExpr analyse(PipelineStep<?, ?> firstStep, AltTypeExpr inputType) {
-        List<AltTypeExpr> pathResults = new ResultTypesAnalyser().visitPaths(firstStep, inputType);
+    public static TypeExpr analyse(PipelineStep<?, ?> firstStep, TypeExpr inputType) {
+        List<TypeExpr> pathResults = new ResultTypesAnalyser().visitPaths(firstStep, inputType);
         return pathResults.stream().reduce((x, y) -> x.or(y)).get();
     }
 
     @Override
-    public AltTypeExpr visit(FindFirstStep<?, ?> step, AltTypeExpr inputType) {
+    public TypeExpr visit(FindFirstStep<?, ?> step, TypeExpr inputType) {
         TypeExpr rt;
         if (step.getElementClass() == null || step.getElementClass() == Object.class) {
-            if (inputType.main.isAnything()) {
+            if (inputType.isAnything()) {
                 rt = atom(RAbstractVector.class).not();
             } else {
-                Set<Type> resTypes = inputType.main.classify().stream().map(c -> CastUtils.elementType(c)).collect(Collectors.toSet());
+                Set<Type> resTypes = inputType.classify().stream().map(c -> CastUtils.elementType(c)).collect(Collectors.toSet());
                 rt = TypeExpr.union(resTypes);
             }
         } else {
@@ -167,7 +135,7 @@ public class ResultTypesAnalyser extends ExecutionPathVisitor<AltTypeExpr> imple
             rt = rt.and(atom(RNull.class).not()).and(atom(RMissing.class).not());
         }
 
-        return inputType.setMain(rt);
+        return rt;
     }
 
     private static TypeExpr inferResultTypeFromSpecializations(Class<? extends CastNode> castNodeClass, TypeExpr inputType) {
@@ -175,56 +143,56 @@ public class ResultTypesAnalyser extends ExecutionPathVisitor<AltTypeExpr> imple
     }
 
     @Override
-    public AltTypeExpr visit(CoercionStep<?, ?> step, AltTypeExpr inputType) {
+    public TypeExpr visit(CoercionStep<?, ?> step, TypeExpr inputType) {
         RType type = step.getType();
         TypeExpr res;
         switch (type) {
             case Integer:
-                res = step.vectorCoercion ? inferResultTypeFromSpecializations(CastIntegerNode.class, inputType.main) : inferResultTypeFromSpecializations(CastIntegerBaseNode.class, inputType.main);
+                res = step.vectorCoercion ? inferResultTypeFromSpecializations(CastIntegerNode.class, inputType) : inferResultTypeFromSpecializations(CastIntegerBaseNode.class, inputType);
                 break;
             case Double:
-                res = step.vectorCoercion ? inferResultTypeFromSpecializations(CastDoubleNode.class, inputType.main) : inferResultTypeFromSpecializations(CastDoubleBaseNode.class, inputType.main);
+                res = step.vectorCoercion ? inferResultTypeFromSpecializations(CastDoubleNode.class, inputType) : inferResultTypeFromSpecializations(CastDoubleBaseNode.class, inputType);
                 break;
             case Character:
-                res = step.vectorCoercion ? inferResultTypeFromSpecializations(CastStringNode.class, inputType.main) : inferResultTypeFromSpecializations(CastStringBaseNode.class, inputType.main);
+                res = step.vectorCoercion ? inferResultTypeFromSpecializations(CastStringNode.class, inputType) : inferResultTypeFromSpecializations(CastStringBaseNode.class, inputType);
                 break;
             case Logical:
-                res = step.vectorCoercion ? inferResultTypeFromSpecializations(CastLogicalNode.class, inputType.main) : inferResultTypeFromSpecializations(CastLogicalBaseNode.class, inputType.main);
+                res = step.vectorCoercion ? inferResultTypeFromSpecializations(CastLogicalNode.class, inputType) : inferResultTypeFromSpecializations(CastLogicalBaseNode.class, inputType);
                 break;
             case Complex:
-                res = inferResultTypeFromSpecializations(CastComplexNode.class, inputType.main);
+                res = inferResultTypeFromSpecializations(CastComplexNode.class, inputType);
                 break;
             case Raw:
-                res = inferResultTypeFromSpecializations(CastRawNode.class, inputType.main);
+                res = inferResultTypeFromSpecializations(CastRawNode.class, inputType);
                 break;
             case Any:
                 TypeExpr funOrVecTp = atom(RFunction.class).or(atom(RAbstractVector.class));
-                res = inputType.main.and(funOrVecTp);
+                res = inputType.and(funOrVecTp);
                 break;
             default:
                 throw RInternalError.shouldNotReachHere(Utils.stringFormat("Unsupported type '%s' in AsVectorStep", type));
         }
 
         if (step.preserveNonVector) {
-            TypeExpr maskedNullMissing = inputType.main.and(atom(RNull.class).or(atom(RMissing.class)));
+            TypeExpr maskedNullMissing = inputType.and(atom(RNull.class).or(atom(RMissing.class)));
             res = res.or(maskedNullMissing);
         }
 
-        return inputType.setMain(res);
+        return res;
     }
 
     @Override
-    public AltTypeExpr visit(MapStep<?, ?> step, AltTypeExpr inputType) {
-        return inputType.setMain(step.getMapper().accept(this, inputType.main));
+    public TypeExpr visit(MapStep<?, ?> step, TypeExpr inputType) {
+        return step.getMapper().accept(this, inputType);
     }
 
     @Override
-    protected AltTypeExpr visitBranch(MapIfStep<?, ?> step, AltTypeExpr inputType, boolean visitTrueBranch) {
-        TypeExpr filterRes = visitFilter(step.getFilter(), inputType.main);
+    protected TypeExpr visitBranch(MapIfStep<?, ?> step, TypeExpr inputType, boolean visitTrueBranch) {
+        TypeExpr filterRes = visitFilter(step.getFilter(), inputType);
         if (visitTrueBranch) {
             if (step.isReturns()) {
-                AltTypeExpr returnedType = trueBranchResultTypes(step, inputType, filterRes);
-                return inputType.addAlt(returnedType.merge());
+                TypeExpr returnedType = trueBranchResultTypes(step, inputType, filterRes);
+                return returnedType;
             } else {
                 return trueBranchResultTypes(step, inputType, filterRes);
             }
@@ -233,64 +201,74 @@ public class ResultTypesAnalyser extends ExecutionPathVisitor<AltTypeExpr> imple
         }
     }
 
-    private static AltTypeExpr trueBranchResultTypes(MapIfStep<?, ?> step, AltTypeExpr inputType, TypeExpr filterRes) {
-        TypeExpr filterTrueCaseType = inputType.main.and(filterRes);
+    private static TypeExpr trueBranchResultTypes(MapIfStep<?, ?> step, TypeExpr inputType, TypeExpr filterRes) {
+        TypeExpr filterTrueCaseType = inputType.and(filterRes);
         if (step.getTrueBranch() != null) {
-            return analyse(step.getTrueBranch(), inputType.setMain(filterTrueCaseType));
+            return analyse(step.getTrueBranch(), filterTrueCaseType);
         } else {
-            return inputType.setMain(filterTrueCaseType);
+            return filterTrueCaseType;
         }
     }
 
-    private static AltTypeExpr falseBranchResultTypes(MapIfStep<?, ?> step, AltTypeExpr inputType, TypeExpr filterRes) {
-        TypeExpr filterFalseCaseType = inputType.main.and(filterRes.not());
+    private static TypeExpr falseBranchResultTypes(MapIfStep<?, ?> step, TypeExpr inputType, TypeExpr filterRes) {
+        TypeExpr filterFalseCaseType = inputType.and(filterRes.not());
         if (step.getFalseBranch() != null) {
-            return analyse(step.getFalseBranch(), inputType.setMain(filterFalseCaseType));
+            return analyse(step.getFalseBranch(), filterFalseCaseType);
         } else {
-            return inputType.setMain(filterFalseCaseType);
+            return filterFalseCaseType;
         }
     }
 
     @Override
-    public AltTypeExpr visit(FilterStep<?, ?> step, AltTypeExpr inputType) {
+    public TypeExpr visit(FilterStep<?, ?> step, TypeExpr inputType) {
         if (step.isWarning()) {
             return inputType;
         } else {
-            return inputType.setMain(inputType.main.and(visitFilter(step.getFilter(), inputType.main)));
+            return inputType.and(visitFilter(step.getFilter(), inputType));
         }
     }
 
     @Override
-    public AltTypeExpr visit(NotNAStep<?> step, AltTypeExpr inputType) {
-        Set<Object> naSamples = inputType.main.toNormalizedConjunctionSet().stream().filter(t -> t instanceof Class).map(t -> CastUtils.naValue((Class<?>) t)).filter(x -> x != null).collect(
+    public TypeExpr visit(NotNAStep<?> step, TypeExpr inputType) {
+        Set<Object> naSamples = inputType.toNormalizedConjunctionSet().stream().filter(t -> t instanceof Class).map(t -> CastUtils.naValue((Class<?>) t)).filter(x -> x != null).collect(
                         Collectors.toSet());
-        TypeExpr resType = inputType.main.lower(step);
+        TypeExpr resType = inputType.lower(step);
         resType = resType.negativeSamples(naSamples);
         if (step.getReplacement() != null) {
             resType = resType.positiveSamples(step.getReplacement());
         }
-        return inputType.setMain(resType);
+        return resType;
     }
 
     @Override
-    public AltTypeExpr visit(DefaultErrorStep<?> step, AltTypeExpr inputType) {
+    public TypeExpr visit(DefaultErrorStep<?> step, TypeExpr inputType) {
         return inputType;
     }
 
     @Override
-    public AltTypeExpr visit(DefaultWarningStep<?> step, AltTypeExpr inputType) {
+    public TypeExpr visit(DefaultWarningStep<?> step, TypeExpr inputType) {
         return inputType;
     }
 
     @Override
-    public AltTypeExpr visit(BoxPrimitiveStep<?> step, AltTypeExpr inputType) {
-        TypeExpr res = TypeExpr.union(RNull.class, RMissing.class, RInteger.class, RLogical.class, RDouble.class, RString.class);
-        return inputType.setMain(res.or(res.not()));
+    public TypeExpr visit(BoxPrimitiveStep<?> step, TypeExpr inputType) {
+        TypeExpr noPrimType = atom(Integer.class).not().and(atom(Byte.class).not()).and(atom(Double.class).not()).and(atom(String.class).not());
+        // cancel potential primitive types in the input
+        TypeExpr noPrimInput = inputType.and(noPrimType);
+        // the positive output type union
+        TypeExpr res = TypeExpr.union(RInteger.class, RLogical.class, RDouble.class, RString.class);
+        // intersect the to stop propagating the primitive types, such as String
+        res = res.and(noPrimInput);
+        // the output of the boxing is actually the union of the positive union with its negation
+        // that represents the fallback output for non-vectors
+        TypeExpr negRes = res.not().and(noPrimInput);
+        res = res.or(negRes);
+        return res;
     }
 
     @Override
-    public AltTypeExpr visit(AttributableCoercionStep<?> step, AltTypeExpr inputType) {
-        return inputType.setMain(inferResultTypeFromSpecializations(CastToAttributableNode.class, inputType.main));
+    public TypeExpr visit(AttributableCoercionStep<?> step, TypeExpr inputType) {
+        return inferResultTypeFromSpecializations(CastToAttributableNode.class, inputType);
     }
 
     // MapperVisitor
diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/ExecutionPathVisitor.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/ExecutionPathVisitor.java
index fc8abb28d8..4c163d8518 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/ExecutionPathVisitor.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/ExecutionPathVisitor.java
@@ -39,6 +39,7 @@ public abstract class ExecutionPathVisitor<T> implements PipelineStepVisitor<T>
     private int mapIfCounter;
     private List<T> results = new ArrayList<>();
 
+    @SuppressWarnings("unchecked")
     public List<T> visitPaths(PipelineStep<?, ?> firstStep, T initial) {
         if (firstStep == null) {
             return Collections.singletonList(initial);
@@ -48,7 +49,13 @@ public abstract class ExecutionPathVisitor<T> implements PipelineStepVisitor<T>
         int n = 1 << mapIfStepStatuses.size();
         for (long i = 1; i < n; i++) {
             bs = BitSet.valueOf(new long[]{i});
-            results.add(firstStep.acceptPipeline(this, initial));
+            T res;
+            try {
+                res = firstStep.acceptPipeline(this, initial);
+            } catch (PathBreakException br) {
+                res = (T) br.result;
+            }
+            results.add(res);
         }
         return results;
     }
@@ -62,9 +69,30 @@ public abstract class ExecutionPathVisitor<T> implements PipelineStepVisitor<T>
         } else {
             visitTrueBranch = bs.get(mapIfStepStatuses.get(step));
         }
-        return visitBranch(step, previous, visitTrueBranch);
+        T res = visitBranch(step, previous, visitTrueBranch);
+        if (step.isReturns() && visitTrueBranch) {
+            throw new PathBreakException(res);
+        } else {
+            return res;
+        }
     }
 
     protected abstract T visitBranch(MapIfStep<?, ?> step, T previous, boolean visitTrueBranch);
 
+    @SuppressWarnings("serial")
+    static final class PathBreakException extends RuntimeException {
+
+        private final Object result;
+
+        private PathBreakException(Object result) {
+            this.result = result;
+        }
+
+        @SuppressWarnings("sync-override")
+        @Override
+        public Throwable fillInStackTrace() {
+            return null;
+        }
+    }
+
 }
-- 
GitLab