From 6df599b5b5b6ac3458afd7b96e4897671b3b5cb6 Mon Sep 17 00:00:00 2001
From: stepan <stepan.sindelar@oracle.com>
Date: Thu, 7 Dec 2017 18:40:06 +0100
Subject: [PATCH] Stricter argument validation in some stats externals

---
 .../src/com/oracle/truffle/r/library/stats/Cdist.java     | 5 +++++
 .../src/com/oracle/truffle/r/library/stats/Cutree.java    | 6 ++++++
 .../com/oracle/truffle/r/library/stats/DoubleCentre.java  | 2 +-
 .../r/nodes/attributes/SpecialAttributesFunctions.java    | 8 ++++++++
 4 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cdist.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cdist.java
index 1d88093b56..72f4dd44d6 100644
--- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cdist.java
+++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cdist.java
@@ -26,6 +26,7 @@ import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAt
 import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetClassAttributeNode;
 import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode;
 import com.oracle.truffle.r.runtime.RError;
+import com.oracle.truffle.r.runtime.RError.Message;
 import com.oracle.truffle.r.runtime.RRuntime;
 import com.oracle.truffle.r.runtime.data.RDataFactory;
 import com.oracle.truffle.r.runtime.data.RDoubleVector;
@@ -58,6 +59,10 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
                     @Cached("create()") SetAttributeNode setAttrNode,
                     @Cached("create()") SetClassAttributeNode setClassAttrNode,
                     @Cached("create()") GetDimAttributeNode getDimNode) {
+        if (!getDimNode.isMatrix(x)) {
+            // Note: otherwise array index out of bounds
+            throw error(Message.MUST_BE_SQUARE_MATRIX, "x");
+        }
         int nr = getDimNode.nrows(x);
         int nc = getDimNode.ncols(x);
         int n = nr * (nr - 1) / 2; /* avoid int overflow for N ~ 50,000 */
diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cutree.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cutree.java
index 8c0052139f..2aaa7620ca 100644
--- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cutree.java
+++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cutree.java
@@ -17,6 +17,7 @@ import com.oracle.truffle.api.dsl.Cached;
 import com.oracle.truffle.api.dsl.Specialization;
 import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode;
 import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode;
+import com.oracle.truffle.r.runtime.RError.Message;
 import com.oracle.truffle.r.runtime.data.RDataFactory;
 import com.oracle.truffle.r.runtime.data.RIntVector;
 import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
@@ -50,6 +51,11 @@ public abstract class Cutree extends RExternalBuiltinNode.Arg2 {
         boolean foundJ;
 
         int n = getDimNode.nrows(merge) + 1;
+        if (!getDimNode.isSquareMatrix(merge)) {
+            // Note: otherwise array index out of bounds
+            throw error(Message.MUST_BE_SQUARE_MATRIX, "x");
+        }
+
         /*
          * The C code uses 1-based indices for the next three arrays and so set the int * value
          * behind the actual start of the array. To keep the logic equivalent, we call adj(k) on the
diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/DoubleCentre.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/DoubleCentre.java
index 6a130dbfd8..ae906e10a4 100644
--- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/DoubleCentre.java
+++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/DoubleCentre.java
@@ -34,7 +34,7 @@ public abstract class DoubleCentre extends RExternalBuiltinNode.Arg1 {
                     @Cached("createNonShared(a)") VectorReuse reuse,
                     @Cached("create()") GetDimAttributeNode getDimNode) {
         int n = getDimNode.nrows(a);
-        if (!getDimNode.isMatrix(a) || n != a.getLength() / n) {
+        if (!getDimNode.isSquareMatrix(a)) {
             // Note: otherwise array index out of bounds
             throw error(Message.MUST_BE_SQUARE_MATRIX, "x");
         }
diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/SpecialAttributesFunctions.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/SpecialAttributesFunctions.java
index 2e6c5c2649..bec22d6526 100644
--- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/SpecialAttributesFunctions.java
+++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/SpecialAttributesFunctions.java
@@ -662,6 +662,14 @@ public final class SpecialAttributesFunctions {
             return nullDimsProfile.profile(dims == null) ? false : dims.getLength() == 2;
         }
 
+        public final boolean isSquareMatrix(RAbstractVector vector) {
+            RIntVector dims = (RIntVector) execute(vector);
+            if (nullDimsProfile.profile(dims == null) || dims.getLength() < 2) {
+                return false;
+            }
+            return dims.getDataAt(0) == dims.getDataAt(1);
+        }
+
         @Specialization(insertBefore = "getAttrFromAttributable")
         protected Object getScalarVectorDims(@SuppressWarnings("unused") RScalarVector x) {
             return null;
-- 
GitLab