From 516ec369f0376d1a323c995fef770da9d7d7a65b Mon Sep 17 00:00:00 2001 From: stepan <stepan.sindelar@oracle.com> Date: Sat, 19 Nov 2016 16:45:11 +0100 Subject: [PATCH] StatsFunctions: support copying of attributes to the result --- .../r/library/stats/StatsFunctions.java | 63 +++++++++++++------ .../attributes/UnaryCopyAttributesNode.java | 4 ++ 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsFunctions.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsFunctions.java index f1835c62d8..0f80a4d24b 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsFunctions.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsFunctions.java @@ -18,6 +18,8 @@ import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.nodes.Node; import com.oracle.truffle.api.profiles.BranchProfile; +import com.oracle.truffle.api.profiles.ConditionProfile; +import com.oracle.truffle.r.nodes.attributes.UnaryCopyAttributesNode; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode; import com.oracle.truffle.r.runtime.RError; @@ -49,8 +51,22 @@ public final class StatsFunctions { double evaluate(double a, double b, double c, boolean x); } + static final class StatFunctionProfiles { + final BranchProfile nan = BranchProfile.create(); + final NACheck aCheck = NACheck.create(); + final NACheck bCheck = NACheck.create(); + final NACheck cCheck = NACheck.create(); + final ConditionProfile copyAttrsFromA = ConditionProfile.createBinaryProfile(); + final ConditionProfile copyAttrsFromB = ConditionProfile.createBinaryProfile(); + final ConditionProfile copyAttrsFromC = ConditionProfile.createBinaryProfile(); + + public static StatFunctionProfiles create() { + return new StatFunctionProfiles(); + } + } + private static RAbstractDoubleVector evaluate3(Node node, Function3_2 function, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, boolean x, boolean y, - BranchProfile nan, NACheck aCheck, NACheck bCheck, NACheck cCheck) { + StatFunctionProfiles profiles, UnaryCopyAttributesNode copyAttributesNode) { int aLength = a.getLength(); int bLength = b.getLength(); int cLength = c.getLength(); @@ -63,17 +79,17 @@ public final class StatsFunctions { boolean complete = true; boolean nans = false; - aCheck.enable(a); - bCheck.enable(b); - cCheck.enable(c); + profiles.aCheck.enable(a); + profiles.bCheck.enable(b); + profiles.cCheck.enable(c); for (int i = 0; i < length; i++) { double aValue = a.getDataAt(i % aLength); double bValue = b.getDataAt(i % bLength); double cValue = c.getDataAt(i % cLength); double value; if (Double.isNaN(aValue) || Double.isNaN(bValue) || Double.isNaN(cValue)) { - nan.enter(); - if (aCheck.check(aValue) || bCheck.check(bValue) || cCheck.check(cValue)) { + profiles.nan.enter(); + if (profiles.aCheck.check(aValue) || profiles.bCheck.check(bValue) || profiles.cCheck.check(cValue)) { value = RRuntime.DOUBLE_NA; complete = false; } else { @@ -82,7 +98,7 @@ public final class StatsFunctions { } else { value = function.evaluate(aValue, bValue, cValue, x, y); if (Double.isNaN(value)) { - nan.enter(); + profiles.nan.enter(); nans = true; } } @@ -91,7 +107,18 @@ public final class StatsFunctions { if (nans) { RError.warning(RError.SHOW_CALLER, RError.Message.NAN_PRODUCED); } - return RDataFactory.createDoubleVector(result, complete); + RDoubleVector resultVec = RDataFactory.createDoubleVector(result, complete); + + // copy attributes if necessary: + if (profiles.copyAttrsFromA.profile(aLength == length)) { + copyAttributesNode.execute(resultVec, a); + } else if (profiles.copyAttrsFromB.profile(bLength == length)) { + copyAttributesNode.execute(resultVec, b); + } else if (profiles.copyAttrsFromC.profile(cLength == length)) { + copyAttributesNode.execute(resultVec, c); + } + + return resultVec; } public abstract static class Function3_2Node extends RExternalBuiltinNode.Arg5 { @@ -111,12 +138,10 @@ public final class StatsFunctions { } @Specialization - protected RAbstractDoubleVector evaluate(RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, boolean x, boolean y, // - @Cached("create()") BranchProfile nan, // - @Cached("create()") NACheck aCheck, // - @Cached("create()") NACheck bCheck, // - @Cached("create()") NACheck cCheck) { - return evaluate3(this, function, a, b, c, x, y, nan, aCheck, bCheck, cCheck); + protected RAbstractDoubleVector evaluate(RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, boolean x, boolean y, + @Cached("create()") StatFunctionProfiles profiles, + @Cached("create()") UnaryCopyAttributesNode copyAttributesNode) { + return evaluate3(this, function, a, b, c, x, y, profiles, copyAttributesNode); } } @@ -136,12 +161,10 @@ public final class StatsFunctions { } @Specialization - protected RAbstractDoubleVector evaluate(RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, boolean x, // - @Cached("create()") BranchProfile nan, // - @Cached("create()") NACheck aCheck, // - @Cached("create()") NACheck bCheck, // - @Cached("create()") NACheck cCheck) { - return evaluate3(this, function, a, b, c, x, false /* dummy */, nan, aCheck, bCheck, cCheck); + protected RAbstractDoubleVector evaluate(RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, boolean x, + @Cached("create()") StatFunctionProfiles profiles, + @Cached("create()") UnaryCopyAttributesNode copyAttributesNode) { + return evaluate3(this, function, a, b, c, x, false /* dummy */, profiles, copyAttributesNode); } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/UnaryCopyAttributesNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/UnaryCopyAttributesNode.java index e6f9d802e6..2099f1442e 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/UnaryCopyAttributesNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/UnaryCopyAttributesNode.java @@ -49,6 +49,10 @@ public abstract class UnaryCopyAttributesNode extends RBaseNode { this.copyAllAttributes = copyAllAttributes; } + public static UnaryCopyAttributesNode create() { + return UnaryCopyAttributesNodeGen.create(true); + } + public abstract RAbstractVector execute(RAbstractVector target, RAbstractVector left); protected boolean containsMetadata(RAbstractVector vector, RAttributeProfiles attrProfiles) { -- GitLab