From 8308f20726241ef099c46de83e655e9a28f7693a Mon Sep 17 00:00:00 2001
From: Michael Haupt <michael.haupt@oracle.com>
Date: Thu, 30 Jan 2014 22:37:31 +0100
Subject: [PATCH] fix cor/cov for a single matrix argument (y missing)

---
 .../truffle/r/nodes/builtin/base/Cor.java     |   2 +-
 .../truffle/r/nodes/builtin/base/Cov.java     |   2 +-
 .../truffle/r/nodes/builtin/base/Covcor.java  | 111 +++++++++++++++++-
 .../truffle/r/test/ExpectedTestOutput.test    |   8 ++
 .../oracle/truffle/r/test/all/AllTests.java   |  10 +-
 .../truffle/r/test/failing/FailingTests.java  |   5 -
 .../r/test/simple/TestSimpleBuiltins.java     |   2 +-
 7 files changed, 123 insertions(+), 17 deletions(-)

diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Cor.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Cor.java
index 5fb1932a0c..5c7430e071 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Cor.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Cor.java
@@ -37,6 +37,6 @@ public abstract class Cor extends RBuiltinNode {
     @Specialization
     @SuppressWarnings("unused")
     public RDoubleVector dimWithDimensions(RDoubleVector vector1, RMissing vector2) {
-        return Covcor.cor(vector1, vector1, false, this.getEncapsulatingSourceSection());
+        return Covcor.cor(vector1, null, false, this.getEncapsulatingSourceSection());
     }
 }
diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Cov.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Cov.java
index b765919ea7..c054f2581b 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Cov.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Cov.java
@@ -37,6 +37,6 @@ public abstract class Cov extends RBuiltinNode {
     @Specialization
     @SuppressWarnings("unused")
     public RDoubleVector dimWithDimensions(RDoubleVector vector1, RMissing vector2) {
-        return Covcor.cov(vector1, vector1, false, this.getEncapsulatingSourceSection());
+        return Covcor.cov(vector1, null, false, this.getEncapsulatingSourceSection());
     }
 }
diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Covcor.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Covcor.java
index 73bb82050b..36bf738ef1 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Covcor.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Covcor.java
@@ -12,6 +12,7 @@
 package com.oracle.truffle.r.nodes.builtin.base;
 
 import com.oracle.truffle.api.*;
+import com.oracle.truffle.api.CompilerDirectives.SlowPath;
 import com.oracle.truffle.r.runtime.*;
 import com.oracle.truffle.r.runtime.data.*;
 import com.oracle.truffle.r.runtime.data.model.*;
@@ -32,7 +33,7 @@ public class Covcor {
         return corcov(x, y, iskendall, false, source);
     }
 
-    @com.oracle.truffle.api.CompilerDirectives.SlowPath
+    @SlowPath
     private static RDoubleVector corcov(RDoubleVector x, RDoubleVector y, boolean iskendall, boolean cor, SourceSection source) {
         boolean ansmat;
         boolean naFail;
@@ -51,7 +52,9 @@ public class Covcor {
             ncx = 1;
         }
 
-        if (isMatrix(y)) {
+        if (y == null) {
+            ncy = ncx;
+        } else if (isMatrix(y)) {
             if (nrows(y) != n) {
                 error("incompatible dimensions");
             }
@@ -82,8 +85,13 @@ public class Covcor {
         double[] ym = new double[ncy];
         RIntVector ind = RDataFactory.createIntVector(n);
 
-        complete2(n, ncx, ncy, x, y, ind, naFail);
-        sd0 = covComplete2(n, ncx, ncy, x, y, xm, ym, ind, answerData, cor, iskendall);
+        if (y == null) {
+            complete1(n, ncx, x, ind, naFail);
+            sd0 = covComplete1(n, ncx, x, xm, ind, answerData, cor, iskendall);
+        } else {
+            complete2(n, ncx, ncy, x, y, ind, naFail);
+            sd0 = covComplete2(n, ncx, ncy, x, y, xm, ym, ind, answerData, cor, iskendall);
+        }
 
         if (sd0) { /* only in cor() */
             RError.warning(source, RError.SD_ZERO);
@@ -108,6 +116,26 @@ public class Covcor {
         return x.getDimensions()[0];
     }
 
+    /** Flags complete rows of x (column-major, n rows, ncx columns) in ind: every entry starts at 1 and is
+     *  reset to 0 if any column holds a NaN in that row; with naFail set, error(...) is called instead. */
+    private static void complete1(int n, int ncx, RDoubleVector x, RIntVector ind, boolean naFail) {
+        for (int row = 0; row < n; row++) {
+            ind.updateDataAt(row, 1, check);
+        }
+        for (int col = 0; col < ncx; col++) {
+            // column col occupies x[col * n] .. x[col * n + n - 1]
+            for (int row = 0; row < n; row++) {
+                if (Double.isNaN(x.getDataAt(col * n + row))) {
+                    if (naFail) {
+                        error("missing observations in cov/cor");
+                    } else {
+                        ind.updateDataAt(row, 0, check);
+                    }
+                }
+            }
+        }
+    }
+
     private static void complete2(int n, int ncx, int ncy, RDoubleVector x, RDoubleVector y, RIntVector ind, boolean naFail) {
         int i;
         int j;
@@ -141,6 +169,81 @@ public class Covcor {
         }
     }
 
+    /**
+     * Computes the covariance matrix of the columns of x (column-major, n rows, ncx columns) over the
+     * rows k with indInput[k] != 0 and stores it symmetrically in ans; if cor is set it is rescaled to
+     * correlations, clamped to [-1, 1] so rounding error cannot leave the valid range.
+     * Kendall's tau is not implemented: kendall == true throws UnsupportedOperationException (if ncx > 0).
+     * @return true if cor is set and an off-diagonal entry became NA due to a zero standard deviation
+     */
+    private static boolean covComplete1(int n, int ncx, RDoubleVector x, double[] xm, RIntVector indInput, double[] ans, boolean cor, boolean kendall) {
+        int n1 = -1;
+        boolean isSd0 = false;
+        /* total number of complete observations */
+        int nobs = 0;
+        for (int k = 0; k < n; k++) {
+            if (indInput.getDataAt(k) != 0) {
+                nobs++;
+            }
+        }
+        if (nobs <= 1) { /* too many missing */
+            for (int i = 0; i < ans.length; i++) {
+                ans[i] = RRuntime.DOUBLE_NA;
+            }
+            return isSd0;
+        }
+
+        // A null ind means no row was excluded, so the per-row test below is skipped.
+        RIntVector ind = (nobs == indInput.getLength()) ? null : indInput;
+
+        if (!kendall) {
+            mean(x, xm, ind, n, ncx, nobs);
+            n1 = nobs - 1;
+        }
+        for (int i = 0; i < ncx; i++) {
+            if (!kendall) {
+                double xxm = xm[i];
+                for (int j = 0; j <= i; j++) {
+                    double yym = xm[j];
+                    double sum = 0.0;
+                    for (int k = 0; k < n; k++) {
+                        if (ind == null || ind.getDataAt(k) != 0) {
+                            sum += (x.getDataAt(i * n + k) - xxm) * (x.getDataAt(j * n + k) - yym);
+                        }
+                    }
+                    double r = sum / n1;
+                    ans[i + j * ncx] = r;
+                    ans[j + i * ncx] = r;
+                }
+            } else { /* Kendall's tau */
+                throw new UnsupportedOperationException("kendall's unsupported");
+            }
+        }
+
+        if (cor) {
+            for (int i = 0; i < ncx; i++) {
+                xm[i] = Math.sqrt(ans[i + i * ncx]);
+            }
+            for (int i = 0; i < ncx; i++) {
+                for (int j = 0; j < i; j++) {
+                    if (xm[i] == 0 || xm[j] == 0) {
+                        isSd0 = true;
+                        ans[i + j * ncx] = RRuntime.DOUBLE_NA;
+                        ans[j + i * ncx] = RRuntime.DOUBLE_NA;
+                    } else {
+                        // Clamp on both sides: rounding can push the ratio marginally past +1 or -1.
+                        double rho = Math.max(-1.0, Math.min(1.0, ans[i + j * ncx] / (xm[i] * xm[j])));
+                        ans[i + j * ncx] = rho;
+                        ans[j + i * ncx] = rho;
+                    }
+                }
+                ans[i + i * ncx] = 1.0;
+            }
+        }
+
+        return isSd0;
+    }
+
     private static boolean covComplete2(int n, int ncx, int ncy, RDoubleVector x, RDoubleVector y, double[] xm, double[] ym, RIntVector indInput, double[] ans, boolean cor, boolean kendall) {
         int n1 = -1;
         int nobs;
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
index 18c6f6335e..c95e78ca32 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test
@@ -3654,6 +3654,14 @@ a1 a2
 #{ cor(c(1,2,3),c(1,2,3)) }
 [1] 1
 
+##com.oracle.truffle.r.test.simple.TestSimpleBuiltins.testCor
+#{ cor(cbind(c(1,1,1), c(1,1,1))) }
+     [,1] [,2]
+[1,]    1   NA
+[2,]   NA    1
+Warning message:
+In cor(cbind(c(1, 1, 1), c(1, 1, 1))) : the standard deviation is zero
+
 ##com.oracle.truffle.r.test.simple.TestSimpleBuiltins.testCor
 #{ cor(cbind(c(3,2,1), c(1,2,3))) }
      [,1] [,2]
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/all/AllTests.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/all/AllTests.java
index b761810493..b290d2b34a 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/all/AllTests.java
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/all/AllTests.java
@@ -5548,6 +5548,11 @@ public class AllTests extends TestBase {
         assertEval("{ cor(cbind(c(3,2,1), c(1,2,3))) }");
     }
 
+    @Test
+    public void TestSimpleBuiltins_testCor_a9306c39144eb95725c311083b3248ba() {
+        assertEval("{ cor(cbind(c(1,1,1), c(1,1,1))) }");
+    }
+
     @Test
     public void TestSimpleBuiltins_testCorIgnore_564c5ee2d2eea4a4b168dca5e6fa9e4f() {
         assertEval("{ cor(cbind(c(1:9,0/0), 101:110)) }");
@@ -5558,11 +5563,6 @@ public class AllTests extends TestBase {
         assertEval("{ round( cor(cbind(c(10,5,4,1), c(2,5,10,5))), digits=5 ) }");
     }
 
-    @Test
-    public void TestSimpleBuiltins_testCorIgnore_a9306c39144eb95725c311083b3248ba() {
-        assertEval("{ cor(cbind(c(1,1,1), c(1,1,1))) }");
-    }
-
     @Test
     public void TestSimpleBuiltins_testCov_4b96d1c7c503defdec6ebab5b659625c() {
         assertEval("{ cov(c(1,2,3),c(1,2,3)) }");
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/failing/FailingTests.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/failing/FailingTests.java
index db64573384..4946aa2d34 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/failing/FailingTests.java
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/failing/FailingTests.java
@@ -1798,11 +1798,6 @@ public class FailingTests extends TestBase {
         assertEval("{ round( cor(cbind(c(10,5,4,1), c(2,5,10,5))), digits=5 ) }");
     }
 
-    @Ignore
-    public void TestSimpleBuiltins_testCorIgnore_a9306c39144eb95725c311083b3248ba() {
-        assertEval("{ cor(cbind(c(1,1,1), c(1,1,1))) }");
-    }
-
     @Ignore
     public void TestSimpleBuiltins_testCrossprod_7f9549017d66ad3dd1583536fa7183d7() {
         assertEval("{ x <- 1:6 ; crossprod(x) }");
diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/simple/TestSimpleBuiltins.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/simple/TestSimpleBuiltins.java
index b49278cdf2..80a1d34b20 100644
--- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/simple/TestSimpleBuiltins.java
+++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/simple/TestSimpleBuiltins.java
@@ -1777,6 +1777,7 @@ public class TestSimpleBuiltins extends TestBase {
         assertEval("{ cor(c(1,2,3),c(1,2,3)) }");
         assertEval("{ as.integer(cor(c(1,2,3),c(1,2,5))*10000000) }");
         assertEval("{ cor(cbind(c(3,2,1), c(1,2,3))) }");
+        assertEval("{ cor(cbind(c(1,1,1), c(1,1,1))) }");
     }
 
     @Test
@@ -1784,7 +1785,6 @@ public class TestSimpleBuiltins extends TestBase {
     public void testCorIgnore() {
         assertEval("{ cor(cbind(c(1:9,0/0), 101:110)) }");
         assertEval("{ round( cor(cbind(c(10,5,4,1), c(2,5,10,5))), digits=5 ) }");
-        assertEval("{ cor(cbind(c(1,1,1), c(1,1,1))) }");
     }
 
     @Test
-- 
GitLab