From 516ec369f0376d1a323c995fef770da9d7d7a65b Mon Sep 17 00:00:00 2001
From: stepan <stepan.sindelar@oracle.com>
Date: Sat, 19 Nov 2016 16:45:11 +0100
Subject: [PATCH] StatsFunctions: support copying of attributes to the result

---
 .../r/library/stats/StatsFunctions.java       | 63 +++++++++++++------
 .../attributes/UnaryCopyAttributesNode.java   |  4 ++
 2 files changed, 47 insertions(+), 20 deletions(-)

diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsFunctions.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsFunctions.java
index f1835c62d8..0f80a4d24b 100644
--- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsFunctions.java
+++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/StatsFunctions.java
@@ -18,6 +18,8 @@ import com.oracle.truffle.api.dsl.Cached;
 import com.oracle.truffle.api.dsl.Specialization;
 import com.oracle.truffle.api.nodes.Node;
 import com.oracle.truffle.api.profiles.BranchProfile;
+import com.oracle.truffle.api.profiles.ConditionProfile;
+import com.oracle.truffle.r.nodes.attributes.UnaryCopyAttributesNode;
 import com.oracle.truffle.r.nodes.builtin.CastBuilder;
 import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode;
 import com.oracle.truffle.r.runtime.RError;
@@ -49,8 +51,22 @@ public final class StatsFunctions {
         double evaluate(double a, double b, double c, boolean x);
     }
 
+    static final class StatFunctionProfiles {
+        final BranchProfile nan = BranchProfile.create();
+        final NACheck aCheck = NACheck.create();
+        final NACheck bCheck = NACheck.create();
+        final NACheck cCheck = NACheck.create();
+        final ConditionProfile copyAttrsFromA = ConditionProfile.createBinaryProfile();
+        final ConditionProfile copyAttrsFromB = ConditionProfile.createBinaryProfile();
+        final ConditionProfile copyAttrsFromC = ConditionProfile.createBinaryProfile();
+
+        public static StatFunctionProfiles create() {
+            return new StatFunctionProfiles();
+        }
+    }
+
     private static RAbstractDoubleVector evaluate3(Node node, Function3_2 function, RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, boolean x, boolean y,
-                    BranchProfile nan, NACheck aCheck, NACheck bCheck, NACheck cCheck) {
+                    StatFunctionProfiles profiles, UnaryCopyAttributesNode copyAttributesNode) {
         int aLength = a.getLength();
         int bLength = b.getLength();
         int cLength = c.getLength();
@@ -63,17 +79,17 @@ public final class StatsFunctions {
 
         boolean complete = true;
         boolean nans = false;
-        aCheck.enable(a);
-        bCheck.enable(b);
-        cCheck.enable(c);
+        profiles.aCheck.enable(a);
+        profiles.bCheck.enable(b);
+        profiles.cCheck.enable(c);
         for (int i = 0; i < length; i++) {
             double aValue = a.getDataAt(i % aLength);
             double bValue = b.getDataAt(i % bLength);
             double cValue = c.getDataAt(i % cLength);
             double value;
             if (Double.isNaN(aValue) || Double.isNaN(bValue) || Double.isNaN(cValue)) {
-                nan.enter();
-                if (aCheck.check(aValue) || bCheck.check(bValue) || cCheck.check(cValue)) {
+                profiles.nan.enter();
+                if (profiles.aCheck.check(aValue) || profiles.bCheck.check(bValue) || profiles.cCheck.check(cValue)) {
                     value = RRuntime.DOUBLE_NA;
                     complete = false;
                 } else {
@@ -82,7 +98,7 @@ public final class StatsFunctions {
             } else {
                 value = function.evaluate(aValue, bValue, cValue, x, y);
                 if (Double.isNaN(value)) {
-                    nan.enter();
+                    profiles.nan.enter();
                     nans = true;
                 }
             }
@@ -91,7 +107,18 @@ public final class StatsFunctions {
         if (nans) {
             RError.warning(RError.SHOW_CALLER, RError.Message.NAN_PRODUCED);
         }
-        return RDataFactory.createDoubleVector(result, complete);
+        RDoubleVector resultVec = RDataFactory.createDoubleVector(result, complete);
+
+        // copy attributes if necessary:
+        if (profiles.copyAttrsFromA.profile(aLength == length)) {
+            copyAttributesNode.execute(resultVec, a);
+        } else if (profiles.copyAttrsFromB.profile(bLength == length)) {
+            copyAttributesNode.execute(resultVec, b);
+        } else if (profiles.copyAttrsFromC.profile(cLength == length)) {
+            copyAttributesNode.execute(resultVec, c);
+        }
+
+        return resultVec;
     }
 
     public abstract static class Function3_2Node extends RExternalBuiltinNode.Arg5 {
@@ -111,12 +138,10 @@ public final class StatsFunctions {
         }
 
         @Specialization
-        protected RAbstractDoubleVector evaluate(RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, boolean x, boolean y, //
-                        @Cached("create()") BranchProfile nan, //
-                        @Cached("create()") NACheck aCheck, //
-                        @Cached("create()") NACheck bCheck, //
-                        @Cached("create()") NACheck cCheck) {
-            return evaluate3(this, function, a, b, c, x, y, nan, aCheck, bCheck, cCheck);
+        protected RAbstractDoubleVector evaluate(RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, boolean x, boolean y,
+                        @Cached("create()") StatFunctionProfiles profiles,
+                        @Cached("create()") UnaryCopyAttributesNode copyAttributesNode) {
+            return evaluate3(this, function, a, b, c, x, y, profiles, copyAttributesNode);
         }
     }
 
@@ -136,12 +161,10 @@ public final class StatsFunctions {
         }
 
         @Specialization
-        protected RAbstractDoubleVector evaluate(RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, boolean x, //
-                        @Cached("create()") BranchProfile nan, //
-                        @Cached("create()") NACheck aCheck, //
-                        @Cached("create()") NACheck bCheck, //
-                        @Cached("create()") NACheck cCheck) {
-            return evaluate3(this, function, a, b, c, x, false /* dummy */, nan, aCheck, bCheck, cCheck);
+        protected RAbstractDoubleVector evaluate(RAbstractDoubleVector a, RAbstractDoubleVector b, RAbstractDoubleVector c, boolean x,
+                        @Cached("create()") StatFunctionProfiles profiles,
+                        @Cached("create()") UnaryCopyAttributesNode copyAttributesNode) {
+            return evaluate3(this, function, a, b, c, x, false /* dummy */, profiles, copyAttributesNode);
         }
     }
 
diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/UnaryCopyAttributesNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/UnaryCopyAttributesNode.java
index e6f9d802e6..2099f1442e 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/UnaryCopyAttributesNode.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/UnaryCopyAttributesNode.java
@@ -49,6 +49,10 @@ public abstract class UnaryCopyAttributesNode extends RBaseNode {
         this.copyAllAttributes = copyAllAttributes;
     }
 
+    public static UnaryCopyAttributesNode create() {
+        return UnaryCopyAttributesNodeGen.create(true);
+    }
+
     public abstract RAbstractVector execute(RAbstractVector target, RAbstractVector left);
 
     protected boolean containsMetadata(RAbstractVector vector, RAttributeProfiles attrProfiles) {
-- 
GitLab