From b48ec8a3890b5432943cad6bdf0a2e7a5034c51f Mon Sep 17 00:00:00 2001 From: Zbynek Slajchrt <zbynek.slajchrt@oracle.com> Date: Thu, 23 Feb 2017 19:46:33 +0100 Subject: [PATCH] A couple of fixes in builtin diagnostics --- .../builtin/ResultTypesAnalyserTest.java | 21 ++- .../r/nodes/casts/ResultTypesAnalyser.java | 132 ++++++++---------- .../builtin/casts/ExecutionPathVisitor.java | 32 ++++- 3 files changed, 102 insertions(+), 83 deletions(-) diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/ResultTypesAnalyserTest.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/ResultTypesAnalyserTest.java index d22ab4cd4c..3c2e35c7e5 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/ResultTypesAnalyserTest.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/builtin/ResultTypesAnalyserTest.java @@ -22,6 +22,8 @@ */ package com.oracle.truffle.r.nodes.builtin; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.emptyIntegerVector; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullConstant; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.atomicIntegerValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.atomicLogicalValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.chain; @@ -167,9 +169,8 @@ public class ResultTypesAnalyserTest { @Test public void testBoxPrimitive() { arg.boxPrimitive(); - TypeExpr expected = TypeExpr.union(RNull.class, RMissing.class, RInteger.class, RLogical.class, - RDouble.class, RString.class); - expected = expected.or(expected.not()); + TypeExpr expected = TypeExpr.union(RInteger.class, RLogical.class, RDouble.class, RString.class); + expected = expected.or((expected.not().and(atom(String.class).not()).and(atom(Double.class).not()).and(atom(Integer.class).not()).and(atom(Byte.class).not()))); assertTypes(expected); } @@ -428,11 +429,23 @@ public class ResultTypesAnalyserTest { } @Test - public void testReturnIf() { + public void testReturnIf1() { arg.mapIf(nullValue(), mark(constant(1), "m1"), mark(constant("abc"), "m2")); assertTypes(atom(String.class).lower(m("m2")).or(atom(Integer.class).lower(m("m1")))); } + @Test + public void testReturnIf2() { + arg.returnIf(nullValue(), emptyIntegerVector()).returnIf(missingValue(), emptyIntegerVector()).asIntegerVector(); + assertTypes(atom(int.class).or(atom(RIntSequence.class)).or(atom(RIntVector.class)), true); + } + + @Test + public void testMustNotBeMissingAndBoxPrimitive() { + arg.mustNotBeMissing().returnIf(nullValue(), nullConstant()).mustBe(stringValue()).boxPrimitive().asStringVector(); + assertTypes(atom(RNull.class).or(atom(RStringVector.class)), true); + } + @Test public void testAllowMissing() { arg.allowMissing().mustBe(stringValue()); diff --git a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/ResultTypesAnalyser.java b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/ResultTypesAnalyser.java index ba943022bf..5df10cc4e0 100644 --- a/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/ResultTypesAnalyser.java +++ b/com.oracle.truffle.r.nodes.test/src/com/oracle/truffle/r/nodes/casts/ResultTypesAnalyser.java @@ -65,7 +65,6 @@ import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.FindFirstStep; import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.MapIfStep; import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.MapStep; import com.oracle.truffle.r.nodes.builtin.casts.PipelineStep.NotNAStep; -import com.oracle.truffle.r.nodes.casts.ResultTypesAnalyser.AltTypeExpr; import com.oracle.truffle.r.nodes.unary.CastComplexNode; import com.oracle.truffle.r.nodes.unary.CastDoubleBaseNode; import com.oracle.truffle.r.nodes.unary.CastDoubleNode; @@ -98,58 +97,27 @@ import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector; import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; -public class ResultTypesAnalyser extends ExecutionPathVisitor<AltTypeExpr> implements MapperVisitor<TypeExpr>, FilterVisitor<TypeExpr>, SubjectVisitor<TypeExpr> { - - static final class AltTypeExpr { - final TypeExpr main; - final TypeExpr alt; - - private AltTypeExpr(TypeExpr mainBranch, TypeExpr altBranch) { - this.main = mainBranch; - this.alt = altBranch; - } - - static AltTypeExpr create() { - return new AltTypeExpr(TypeExpr.ANYTHING, null); - } - - TypeExpr merge() { - return alt == null ? main : main.or(alt); - } - - AltTypeExpr setMain(TypeExpr newMainType) { - return newMainType == null ? this : new AltTypeExpr(newMainType, alt); - } - - AltTypeExpr addAlt(TypeExpr newAltType) { - return newAltType == null ? this : new AltTypeExpr(main, alt == null ? newAltType : alt.or(newAltType)); - } - - AltTypeExpr or(AltTypeExpr other) { - return setMain(main.or(other.main)).addAlt(other.alt); - } - - } +public class ResultTypesAnalyser extends ExecutionPathVisitor<TypeExpr> implements MapperVisitor<TypeExpr>, FilterVisitor<TypeExpr>, SubjectVisitor<TypeExpr> { private static final TypeExpr NOT_NULL_NOT_MISSING = atom(RNull.class).not().and(atom(RMissing.class).not()); public static TypeExpr analyse(PipelineStep<?, ?> firstStep) { - return analyse(firstStep, AltTypeExpr.create()).merge(); + return analyse(firstStep, TypeExpr.ANYTHING); } - public static AltTypeExpr analyse(PipelineStep<?, ?> firstStep, AltTypeExpr inputType) { - List<AltTypeExpr> pathResults = new ResultTypesAnalyser().visitPaths(firstStep, inputType); + public static TypeExpr analyse(PipelineStep<?, ?> firstStep, TypeExpr inputType) { + List<TypeExpr> pathResults = new ResultTypesAnalyser().visitPaths(firstStep, inputType); return pathResults.stream().reduce((x, y) -> x.or(y)).get(); } @Override - public AltTypeExpr visit(FindFirstStep<?, ?> step, AltTypeExpr inputType) { + public TypeExpr visit(FindFirstStep<?, ?> step, TypeExpr inputType) { TypeExpr rt; if (step.getElementClass() == null || step.getElementClass() == Object.class) { - if (inputType.main.isAnything()) { + if (inputType.isAnything()) { rt = atom(RAbstractVector.class).not(); } else { - Set<Type> resTypes = inputType.main.classify().stream().map(c -> CastUtils.elementType(c)).collect(Collectors.toSet()); + Set<Type> resTypes = inputType.classify().stream().map(c -> CastUtils.elementType(c)).collect(Collectors.toSet()); rt = TypeExpr.union(resTypes); } } else { @@ -167,7 +135,7 @@ public class ResultTypesAnalyser extends ExecutionPathVisitor<AltTypeExpr> imple rt = rt.and(atom(RNull.class).not()).and(atom(RMissing.class).not()); } - return inputType.setMain(rt); + return rt; } private static TypeExpr inferResultTypeFromSpecializations(Class<? extends CastNode> castNodeClass, TypeExpr inputType) { @@ -175,56 +143,56 @@ public class ResultTypesAnalyser extends ExecutionPathVisitor<AltTypeExpr> imple } @Override - public AltTypeExpr visit(CoercionStep<?, ?> step, AltTypeExpr inputType) { + public TypeExpr visit(CoercionStep<?, ?> step, TypeExpr inputType) { RType type = step.getType(); TypeExpr res; switch (type) { case Integer: - res = step.vectorCoercion ? inferResultTypeFromSpecializations(CastIntegerNode.class, inputType.main) : inferResultTypeFromSpecializations(CastIntegerBaseNode.class, inputType.main); + res = step.vectorCoercion ? inferResultTypeFromSpecializations(CastIntegerNode.class, inputType) : inferResultTypeFromSpecializations(CastIntegerBaseNode.class, inputType); break; case Double: - res = step.vectorCoercion ? inferResultTypeFromSpecializations(CastDoubleNode.class, inputType.main) : inferResultTypeFromSpecializations(CastDoubleBaseNode.class, inputType.main); + res = step.vectorCoercion ? inferResultTypeFromSpecializations(CastDoubleNode.class, inputType) : inferResultTypeFromSpecializations(CastDoubleBaseNode.class, inputType); break; case Character: - res = step.vectorCoercion ? inferResultTypeFromSpecializations(CastStringNode.class, inputType.main) : inferResultTypeFromSpecializations(CastStringBaseNode.class, inputType.main); + res = step.vectorCoercion ? inferResultTypeFromSpecializations(CastStringNode.class, inputType) : inferResultTypeFromSpecializations(CastStringBaseNode.class, inputType); break; case Logical: - res = step.vectorCoercion ? inferResultTypeFromSpecializations(CastLogicalNode.class, inputType.main) : inferResultTypeFromSpecializations(CastLogicalBaseNode.class, inputType.main); + res = step.vectorCoercion ? inferResultTypeFromSpecializations(CastLogicalNode.class, inputType) : inferResultTypeFromSpecializations(CastLogicalBaseNode.class, inputType); break; case Complex: - res = inferResultTypeFromSpecializations(CastComplexNode.class, inputType.main); + res = inferResultTypeFromSpecializations(CastComplexNode.class, inputType); break; case Raw: - res = inferResultTypeFromSpecializations(CastRawNode.class, inputType.main); + res = inferResultTypeFromSpecializations(CastRawNode.class, inputType); break; case Any: TypeExpr funOrVecTp = atom(RFunction.class).or(atom(RAbstractVector.class)); - res = inputType.main.and(funOrVecTp); + res = inputType.and(funOrVecTp); break; default: throw RInternalError.shouldNotReachHere(Utils.stringFormat("Unsupported type '%s' in AsVectorStep", type)); } if (step.preserveNonVector) { - TypeExpr maskedNullMissing = inputType.main.and(atom(RNull.class).or(atom(RMissing.class))); + TypeExpr maskedNullMissing = inputType.and(atom(RNull.class).or(atom(RMissing.class))); res = res.or(maskedNullMissing); } - return inputType.setMain(res); + return res; } @Override - public AltTypeExpr visit(MapStep<?, ?> step, AltTypeExpr inputType) { - return inputType.setMain(step.getMapper().accept(this, inputType.main)); + public TypeExpr visit(MapStep<?, ?> step, TypeExpr inputType) { + return step.getMapper().accept(this, inputType); } @Override - protected AltTypeExpr visitBranch(MapIfStep<?, ?> step, AltTypeExpr inputType, boolean visitTrueBranch) { - TypeExpr filterRes = visitFilter(step.getFilter(), inputType.main); + protected TypeExpr visitBranch(MapIfStep<?, ?> step, TypeExpr inputType, boolean visitTrueBranch) { + TypeExpr filterRes = visitFilter(step.getFilter(), inputType); if (visitTrueBranch) { if (step.isReturns()) { - AltTypeExpr returnedType = trueBranchResultTypes(step, inputType, filterRes); - return inputType.addAlt(returnedType.merge()); + TypeExpr returnedType = trueBranchResultTypes(step, inputType, filterRes); + return returnedType; } else { return trueBranchResultTypes(step, inputType, filterRes); } @@ -233,64 +201,74 @@ public class ResultTypesAnalyser extends ExecutionPathVisitor<AltTypeExpr> imple } } - private static AltTypeExpr trueBranchResultTypes(MapIfStep<?, ?> step, AltTypeExpr inputType, TypeExpr filterRes) { - TypeExpr filterTrueCaseType = inputType.main.and(filterRes); + private static TypeExpr trueBranchResultTypes(MapIfStep<?, ?> step, TypeExpr inputType, TypeExpr filterRes) { + TypeExpr filterTrueCaseType = inputType.and(filterRes); if (step.getTrueBranch() != null) { - return analyse(step.getTrueBranch(), inputType.setMain(filterTrueCaseType)); + return analyse(step.getTrueBranch(), filterTrueCaseType); } else { - return inputType.setMain(filterTrueCaseType); + return filterTrueCaseType; } } - private static AltTypeExpr falseBranchResultTypes(MapIfStep<?, ?> step, AltTypeExpr inputType, TypeExpr filterRes) { - TypeExpr filterFalseCaseType = inputType.main.and(filterRes.not()); + private static TypeExpr falseBranchResultTypes(MapIfStep<?, ?> step, TypeExpr inputType, TypeExpr filterRes) { + TypeExpr filterFalseCaseType = inputType.and(filterRes.not()); if (step.getFalseBranch() != null) { - return analyse(step.getFalseBranch(), inputType.setMain(filterFalseCaseType)); + return analyse(step.getFalseBranch(), filterFalseCaseType); } else { - return inputType.setMain(filterFalseCaseType); + return filterFalseCaseType; } } @Override - public AltTypeExpr visit(FilterStep<?, ?> step, AltTypeExpr inputType) { + public TypeExpr visit(FilterStep<?, ?> step, TypeExpr inputType) { if (step.isWarning()) { return inputType; } else { - return inputType.setMain(inputType.main.and(visitFilter(step.getFilter(), inputType.main))); + return inputType.and(visitFilter(step.getFilter(), inputType)); } } @Override - public AltTypeExpr visit(NotNAStep<?> step, AltTypeExpr inputType) { - Set<Object> naSamples = inputType.main.toNormalizedConjunctionSet().stream().filter(t -> t instanceof Class).map(t -> CastUtils.naValue((Class<?>) t)).filter(x -> x != null).collect( + public TypeExpr visit(NotNAStep<?> step, TypeExpr inputType) { + Set<Object> naSamples = inputType.toNormalizedConjunctionSet().stream().filter(t -> t instanceof Class).map(t -> CastUtils.naValue((Class<?>) t)).filter(x -> x != null).collect( Collectors.toSet()); - TypeExpr resType = inputType.main.lower(step); + TypeExpr resType = inputType.lower(step); resType = resType.negativeSamples(naSamples); if (step.getReplacement() != null) { resType = resType.positiveSamples(step.getReplacement()); } - return inputType.setMain(resType); + return resType; } @Override - public AltTypeExpr visit(DefaultErrorStep<?> step, AltTypeExpr inputType) { + public TypeExpr visit(DefaultErrorStep<?> step, TypeExpr inputType) { return inputType; } @Override - public AltTypeExpr visit(DefaultWarningStep<?> step, AltTypeExpr inputType) { + public TypeExpr visit(DefaultWarningStep<?> step, TypeExpr inputType) { return inputType; } @Override - public AltTypeExpr visit(BoxPrimitiveStep<?> step, AltTypeExpr inputType) { - TypeExpr res = TypeExpr.union(RNull.class, RMissing.class, RInteger.class, RLogical.class, RDouble.class, RString.class); - return inputType.setMain(res.or(res.not())); + public TypeExpr visit(BoxPrimitiveStep<?> step, TypeExpr inputType) { + TypeExpr noPrimType = atom(Integer.class).not().and(atom(Byte.class).not()).and(atom(Double.class).not()).and(atom(String.class).not()); + // cancel potential primitive types in the input + TypeExpr noPrimInput = inputType.and(noPrimType); + // the positive output type union + TypeExpr res = TypeExpr.union(RInteger.class, RLogical.class, RDouble.class, RString.class); + // intersect the to stop propagating the primitive types, such as String + res = res.and(noPrimInput); + // the output of the boxing is actually the union of the positive union with its negation + // that represents the fallback output for non-vectors + TypeExpr negRes = res.not().and(noPrimInput); + res = res.or(negRes); + return res; } @Override - public AltTypeExpr visit(AttributableCoercionStep<?> step, AltTypeExpr inputType) { - return inputType.setMain(inferResultTypeFromSpecializations(CastToAttributableNode.class, inputType.main)); + public TypeExpr visit(AttributableCoercionStep<?> step, TypeExpr inputType) { + return inferResultTypeFromSpecializations(CastToAttributableNode.class, inputType); } // MapperVisitor diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/ExecutionPathVisitor.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/ExecutionPathVisitor.java index fc8abb28d8..4c163d8518 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/ExecutionPathVisitor.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/casts/ExecutionPathVisitor.java @@ -39,6 +39,7 @@ public abstract class ExecutionPathVisitor<T> implements PipelineStepVisitor<T> private int mapIfCounter; private List<T> results = new ArrayList<>(); + @SuppressWarnings("unchecked") public List<T> visitPaths(PipelineStep<?, ?> firstStep, T initial) { if (firstStep == null) { return Collections.singletonList(initial); @@ -48,7 +49,13 @@ public abstract class ExecutionPathVisitor<T> implements PipelineStepVisitor<T> int n = 1 << mapIfStepStatuses.size(); for (long i = 1; i < n; i++) { bs = BitSet.valueOf(new long[]{i}); - results.add(firstStep.acceptPipeline(this, initial)); + T res; + try { + res = firstStep.acceptPipeline(this, initial); + } catch (PathBreakException br) { + res = (T) br.result; + } + results.add(res); } return results; } @@ -62,9 +69,30 @@ public abstract class ExecutionPathVisitor<T> implements PipelineStepVisitor<T> } else { visitTrueBranch = bs.get(mapIfStepStatuses.get(step)); } - return visitBranch(step, previous, visitTrueBranch); + T res = visitBranch(step, previous, visitTrueBranch); + if (step.isReturns() && visitTrueBranch) { + throw new PathBreakException(res); + } else { + return res; + } } protected abstract T visitBranch(MapIfStep<?, ?> step, T previous, boolean visitTrueBranch); + @SuppressWarnings("serial") + static final class PathBreakException extends RuntimeException { + + private final Object result; + + private PathBreakException(Object result) { + this.result = result; + } + + @SuppressWarnings("sync-override") + @Override + public Throwable fillInStackTrace() { + return null; + } + } + } -- GitLab