diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cdist.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cdist.java index 1d88093b56227120073c643e15d546084b029479..72f4dd44d64b5c665a62ba0acd643d825b031e13 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cdist.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cdist.java @@ -26,6 +26,7 @@ import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAt import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetClassAttributeNode; import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode; import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDoubleVector; @@ -58,6 +59,10 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { @Cached("create()") SetAttributeNode setAttrNode, @Cached("create()") SetClassAttributeNode setClassAttrNode, @Cached("create()") GetDimAttributeNode getDimNode) { + if (!getDimNode.isMatrix(x)) { + // Note: otherwise array index out of bounds + throw error(Message.MUST_BE_SQUARE_MATRIX, "x"); + } int nr = getDimNode.nrows(x); int nc = getDimNode.ncols(x); int n = nr * (nr - 1) / 2; /* avoid int overflow for N ~ 50,000 */ diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cutree.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cutree.java index 8c0052139f253fd4233b127aeefe7f0cedaeb1df..2aaa7620ca0545411d051f3674c42dce572a4c46 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cutree.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/Cutree.java @@ -17,6 +17,7 @@ import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode; +import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; @@ -50,6 +51,11 @@ public abstract class Cutree extends RExternalBuiltinNode.Arg2 { boolean foundJ; int n = getDimNode.nrows(merge) + 1; + if (!getDimNode.isSquareMatrix(merge)) { + // Note: otherwise array index out of bounds + throw error(Message.MUST_BE_SQUARE_MATRIX, "x"); + } + /* * The C code uses 1-based indices for the next three arrays and so set the int * value * behind the actual start of the array. To keep the logic equivalent, we call adj(k) on the diff --git a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/DoubleCentre.java b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/DoubleCentre.java index 6a130dbfd843a22dff850fd08eb89f2a487ebc62..ae906e10a4f0103025a3da2aef20662aa5da1a77 100644 --- a/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/DoubleCentre.java +++ b/com.oracle.truffle.r.library/src/com/oracle/truffle/r/library/stats/DoubleCentre.java @@ -34,7 +34,7 @@ public abstract class DoubleCentre extends RExternalBuiltinNode.Arg1 { @Cached("createNonShared(a)") VectorReuse reuse, @Cached("create()") GetDimAttributeNode getDimNode) { int n = getDimNode.nrows(a); - if (!getDimNode.isMatrix(a) || n != a.getLength() / n) { + if (!getDimNode.isSquareMatrix(a)) { // Note: otherwise array index out of bounds throw error(Message.MUST_BE_SQUARE_MATRIX, "x"); } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/SpecialAttributesFunctions.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/SpecialAttributesFunctions.java index 2e6c5c264963c3b02f1eb8209d60905983f46399..bec22d65269c1232586d6a49194c628a18d0ba53 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/SpecialAttributesFunctions.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/SpecialAttributesFunctions.java @@ -662,6 +662,14 @@ public final class SpecialAttributesFunctions { return nullDimsProfile.profile(dims == null) ? false : dims.getLength() == 2; } + public final boolean isSquareMatrix(RAbstractVector vector) { + RIntVector dims = (RIntVector) execute(vector); + if (nullDimsProfile.profile(dims == null) || dims.getLength() < 2) { + return false; + } + return dims.getDataAt(0) == dims.getDataAt(1); + } + @Specialization(insertBefore = "getAttrFromAttributable") protected Object getScalarVectorDims(@SuppressWarnings("unused") RScalarVector x) { return null;