From 8308f20726241ef099c46de83e655e9a28f7693a Mon Sep 17 00:00:00 2001 From: Michael Haupt <michael.haupt@oracle.com> Date: Thu, 30 Jan 2014 22:37:31 +0100 Subject: [PATCH] fix bug in cor implementation --- .../truffle/r/nodes/builtin/base/Cor.java | 2 +- .../truffle/r/nodes/builtin/base/Cov.java | 2 +- .../truffle/r/nodes/builtin/base/Covcor.java | 111 +++++++++++++++++- .../truffle/r/test/ExpectedTestOutput.test | 8 ++ .../oracle/truffle/r/test/all/AllTests.java | 10 +- .../truffle/r/test/failing/FailingTests.java | 5 - .../r/test/simple/TestSimpleBuiltins.java | 2 +- 7 files changed, 123 insertions(+), 17 deletions(-) diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Cor.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Cor.java index 5fb1932a0c..5c7430e071 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Cor.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Cor.java @@ -37,6 +37,6 @@ public abstract class Cor extends RBuiltinNode { @Specialization @SuppressWarnings("unused") public RDoubleVector dimWithDimensions(RDoubleVector vector1, RMissing vector2) { - return Covcor.cor(vector1, vector1, false, this.getEncapsulatingSourceSection()); + return Covcor.cor(vector1, null, false, this.getEncapsulatingSourceSection()); } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Cov.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Cov.java index b765919ea7..c054f2581b 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Cov.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Cov.java @@ -37,6 +37,6 @@ public abstract class Cov extends RBuiltinNode { @Specialization @SuppressWarnings("unused") public RDoubleVector dimWithDimensions(RDoubleVector vector1, RMissing vector2) { - return Covcor.cov(vector1, vector1, false, this.getEncapsulatingSourceSection()); + return Covcor.cov(vector1, null, false, this.getEncapsulatingSourceSection()); } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Covcor.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Covcor.java index 73bb82050b..36bf738ef1 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Covcor.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/builtin/base/Covcor.java @@ -12,6 +12,7 @@ package com.oracle.truffle.r.nodes.builtin.base; import com.oracle.truffle.api.*; +import com.oracle.truffle.api.CompilerDirectives.SlowPath; import com.oracle.truffle.r.runtime.*; import com.oracle.truffle.r.runtime.data.*; import com.oracle.truffle.r.runtime.data.model.*; @@ -32,7 +33,7 @@ public class Covcor { return corcov(x, y, iskendall, false, source); } - @com.oracle.truffle.api.CompilerDirectives.SlowPath + @SlowPath private static RDoubleVector corcov(RDoubleVector x, RDoubleVector y, boolean iskendall, boolean cor, SourceSection source) { boolean ansmat; boolean naFail; @@ -51,7 +52,9 @@ public class Covcor { ncx = 1; } - if (isMatrix(y)) { + if (y == null) { + ncy = ncx; + } else if (isMatrix(y)) { if (nrows(y) != n) { error("incompatible dimensions"); } @@ -82,8 +85,13 @@ public class Covcor { double[] ym = new double[ncy]; RIntVector ind = RDataFactory.createIntVector(n); - complete2(n, ncx, ncy, x, y, ind, naFail); - sd0 = covComplete2(n, ncx, ncy, x, y, xm, ym, ind, answerData, cor, iskendall); + if (y == null) { + complete1(n, ncx, x, ind, naFail); + sd0 = covComplete1(n, ncx, x, xm, ind, answerData, cor, iskendall); + } else { + complete2(n, ncx, ncy, x, y, ind, naFail); + sd0 = covComplete2(n, ncx, ncy, x, y, xm, ym, ind, answerData, cor, iskendall); + } if (sd0) { /* only in cor() */ RError.warning(source, RError.SD_ZERO); @@ -108,6 +116,26 @@ public class Covcor { return x.getDimensions()[0]; } + private static void complete1(int n, int ncx, RDoubleVector x, RIntVector ind, boolean naFail) { + int i; + int j; + for (i = 0; i < n; i++) { + ind.updateDataAt(i, 1, check); + } + for (j = 0; j < ncx; j++) { + // z = &x[j * n]; + for (i = 0; i < n; i++) { + if (Double.isNaN(x.getDataAt(j * n + i))) { + if (naFail) { + error("missing observations in cov/cor"); + } else { + ind.updateDataAt(i, 0, check); + } + } + } + } + } + private static void complete2(int n, int ncx, int ncy, RDoubleVector x, RDoubleVector y, RIntVector ind, boolean naFail) { int i; int j; @@ -141,6 +169,81 @@ public class Covcor { } } + private static boolean covComplete1(int n, int ncx, RDoubleVector x, double[] xm, RIntVector indInput, double[] ans, boolean cor, boolean kendall) { + int n1 = -1; + int nobs; + boolean isSd0 = false; + + /* total number of complete observations */ + nobs = 0; + for (int k = 0; k < n; k++) { + if (indInput.getDataAt(k) != 0) { + nobs++; + } + } + if (nobs <= 1) { /* too many missing */ + for (int i = 0; i < ans.length; i++) { + ans[i] = RRuntime.DOUBLE_NA; + } + return isSd0; + } + + RIntVector ind = indInput; + if (nobs == ind.getLength()) { + // No values of ind are zeroed. + ind = null; + } + + if (!kendall) { + mean(x, xm, ind, n, ncx, nobs); + n1 = nobs - 1; + } + for (int i = 0; i < ncx; i++) { + if (!kendall) { + double xxm = xm[i]; + for (int j = 0; j <= i; j++) { + double yym = xm[j]; + double sum = 0.0; + for (int k = 0; k < n; k++) { + if (ind == null || ind.getDataAt(k) != 0) { + sum += (x.getDataAt(i * n + k) - xxm) * (x.getDataAt(j * n + k) - yym); + } + } + double r = sum / n1; + ans[i + j * ncx] = r; + ans[j + i * ncx] = r; + } + } else { /* Kendall's tau */ + throw new UnsupportedOperationException("kendall's unsupported"); + } + } + + if (cor) { + for (int i = 0; i < ncx; i++) { + xm[i] = Math.sqrt(ans[i + i * ncx]); + } + for (int i = 0; i < ncx; i++) { + for (int j = 0; j < i; j++) { + if (xm[i] == 0 || xm[j] == 0) { + isSd0 = true; + ans[i + j * ncx] = RRuntime.DOUBLE_NA; + ans[j + i * ncx] = RRuntime.DOUBLE_NA; + } else { + double sum = ans[i + j * ncx] / (xm[i] * xm[j]); + if (sum > 1.0) { + sum = 1.0; + } + ans[i + j * ncx] = sum; + ans[j + i * ncx] = sum; + } + } + ans[i + i * ncx] = 1.0; + } + } + + return isSd0; + } + private static boolean covComplete2(int n, int ncx, int ncy, RDoubleVector x, RDoubleVector y, double[] xm, double[] ym, RIntVector indInput, double[] ans, boolean cor, boolean kendall) { int n1 = -1; int nobs; diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test index 18c6f6335e..c95e78ca32 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/ExpectedTestOutput.test @@ -3654,6 +3654,14 @@ a1 a2 #{ cor(c(1,2,3),c(1,2,3)) } [1] 1 +##com.oracle.truffle.r.test.simple.TestSimpleBuiltins.testCor +#{ cor(cbind(c(1,1,1), c(1,1,1))) } + [,1] [,2] +[1,] 1 NA +[2,] NA 1 +Warning message: +In cor(cbind(c(1, 1, 1), c(1, 1, 1))) : the standard deviation is zero + ##com.oracle.truffle.r.test.simple.TestSimpleBuiltins.testCor #{ cor(cbind(c(3,2,1), c(1,2,3))) } [,1] [,2] diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/all/AllTests.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/all/AllTests.java index b761810493..b290d2b34a 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/all/AllTests.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/all/AllTests.java @@ -5548,6 +5548,11 @@ public class AllTests extends TestBase { assertEval("{ cor(cbind(c(3,2,1), c(1,2,3))) }"); } + @Test + public void TestSimpleBuiltins_testCor_a9306c39144eb95725c311083b3248ba() { + assertEval("{ cor(cbind(c(1,1,1), c(1,1,1))) }"); + } + @Test public void TestSimpleBuiltins_testCorIgnore_564c5ee2d2eea4a4b168dca5e6fa9e4f() { assertEval("{ cor(cbind(c(1:9,0/0), 101:110)) }"); @@ -5558,11 +5563,6 @@ public class AllTests extends TestBase { assertEval("{ round( cor(cbind(c(10,5,4,1), c(2,5,10,5))), digits=5 ) }"); } - @Test - public void TestSimpleBuiltins_testCorIgnore_a9306c39144eb95725c311083b3248ba() { - assertEval("{ cor(cbind(c(1,1,1), c(1,1,1))) }"); - } - @Test public void TestSimpleBuiltins_testCov_4b96d1c7c503defdec6ebab5b659625c() { assertEval("{ cov(c(1,2,3),c(1,2,3)) }"); diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/failing/FailingTests.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/failing/FailingTests.java index db64573384..4946aa2d34 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/failing/FailingTests.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/failing/FailingTests.java @@ -1798,11 +1798,6 @@ public class FailingTests extends TestBase { assertEval("{ round( cor(cbind(c(10,5,4,1), c(2,5,10,5))), digits=5 ) }"); } - @Ignore - public void TestSimpleBuiltins_testCorIgnore_a9306c39144eb95725c311083b3248ba() { - assertEval("{ cor(cbind(c(1,1,1), c(1,1,1))) }"); - } - @Ignore public void TestSimpleBuiltins_testCrossprod_7f9549017d66ad3dd1583536fa7183d7() { assertEval("{ x <- 1:6 ; crossprod(x) }"); diff --git a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/simple/TestSimpleBuiltins.java b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/simple/TestSimpleBuiltins.java index b49278cdf2..80a1d34b20 100644 --- a/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/simple/TestSimpleBuiltins.java +++ b/com.oracle.truffle.r.test/src/com/oracle/truffle/r/test/simple/TestSimpleBuiltins.java @@ -1777,6 +1777,7 @@ public class TestSimpleBuiltins extends TestBase { assertEval("{ cor(c(1,2,3),c(1,2,3)) }"); assertEval("{ as.integer(cor(c(1,2,3),c(1,2,5))*10000000) }"); assertEval("{ cor(cbind(c(3,2,1), c(1,2,3))) }"); + assertEval("{ cor(cbind(c(1, 1, 1), c(1, 1, 1))) }"); } @Test @@ -1784,7 +1785,6 @@ public class TestSimpleBuiltins extends TestBase { public void testCorIgnore() { assertEval("{ cor(cbind(c(1:9,0/0), 101:110)) }"); assertEval("{ round( cor(cbind(c(10,5,4,1), c(2,5,10,5))), digits=5 ) }"); - assertEval("{ cor(cbind(c(1,1,1), c(1,1,1))) }"); } @Test -- GitLab