Skip to content
Snippets Groups Projects
Commit a97a4e9e authored by Lukas Stadler's avatar Lukas Stadler
Browse files

convert Cdist, Cutree and DoubleCentre to VectorAccess

parent ebc0383d
No related branches found
No related tags found
No related merge requests found
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
*/ */
package com.oracle.truffle.r.library.stats; package com.oracle.truffle.r.library.stats;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.missingValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.instanceOf; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.instanceOf;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.missingValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
import static com.oracle.truffle.r.runtime.nmath.MathConstants.DBL_MIN; import static com.oracle.truffle.r.runtime.nmath.MathConstants.DBL_MIN;
import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Cached;
...@@ -25,8 +25,6 @@ import com.oracle.truffle.r.nodes.attributes.SetAttributeNode; ...@@ -25,8 +25,6 @@ import com.oracle.truffle.r.nodes.attributes.SetAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetClassAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetClassAttributeNode;
import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode; import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode;
import com.oracle.truffle.r.runtime.data.nodes.ReadAccessor;
import com.oracle.truffle.r.runtime.data.nodes.VectorReadAccess;
import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDataFactory;
...@@ -34,9 +32,12 @@ import com.oracle.truffle.r.runtime.data.RDoubleVector; ...@@ -34,9 +32,12 @@ import com.oracle.truffle.r.runtime.data.RDoubleVector;
import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RList;
import com.oracle.truffle.r.runtime.data.RStringVector; import com.oracle.truffle.r.runtime.data.RStringVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator;
import com.oracle.truffle.r.runtime.ops.na.NACheck; import com.oracle.truffle.r.runtime.ops.na.NACheck;
public abstract class Cdist extends RExternalBuiltinNode.Arg4 { public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
private static final NACheck naCheck = NACheck.create(); private static final NACheck naCheck = NACheck.create();
@Child private GetFixedAttributeNode getNamesAttrNode = GetFixedAttributeNode.createNames(); @Child private GetFixedAttributeNode getNamesAttrNode = GetFixedAttributeNode.createNames();
...@@ -49,9 +50,10 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -49,9 +50,10 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
casts.arg(3).asDoubleVector().findFirst(); casts.arg(3).asDoubleVector().findFirst();
} }
@Specialization(guards = "method == cachedMethod") @Specialization(guards = {"method == cachedMethod", "xAccess.supports(x)"})
protected RDoubleVector cdist(RAbstractDoubleVector x, @SuppressWarnings("unused") int method, RList list, double p, @SuppressWarnings("unused") @Cached("method") int cachedMethod, protected RDoubleVector cdist(RAbstractDoubleVector x, @SuppressWarnings("unused") int method, RList list, double p,
@Cached("create()") VectorReadAccess.Double xAccess, @Cached("method") @SuppressWarnings("unused") int cachedMethod,
@Cached("x.access()") VectorAccess xAccess,
@Cached("getMethod(method)") Method methodObj, @Cached("getMethod(method)") Method methodObj,
@Cached("create()") SetAttributeNode setAttrNode, @Cached("create()") SetAttributeNode setAttrNode,
@Cached("create()") SetClassAttributeNode setClassAttrNode, @Cached("create()") SetClassAttributeNode setClassAttrNode,
...@@ -60,8 +62,10 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -60,8 +62,10 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
int nc = getDimNode.ncols(x); int nc = getDimNode.ncols(x);
int n = nr * (nr - 1) / 2; /* avoid int overflow for N ~ 50,000 */ int n = nr * (nr - 1) / 2; /* avoid int overflow for N ~ 50,000 */
double[] ans = new double[n]; double[] ans = new double[n];
RDoubleVector xm = x.materialize();
rdistance(new ReadAccessor.Double(x, xAccess), nr, nc, ans, false, methodObj, p); try (RandomIterator xIter = xAccess.randomAccess(x)) {
rdistance(xAccess, xIter, nr, nc, ans, false, methodObj, p);
}
RDoubleVector result = RDataFactory.createDoubleVector(ans, naCheck.neverSeenNA()); RDoubleVector result = RDataFactory.createDoubleVector(ans, naCheck.neverSeenNA());
DynamicObject resultAttrs = result.initAttributes(); DynamicObject resultAttrs = result.initAttributes();
...@@ -81,6 +85,14 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -81,6 +85,14 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
return result; return result;
} }
@Specialization(replaces = "cdist")
protected RDoubleVector cdistGeneric(RAbstractDoubleVector x, int method, RList list, double p,
@Cached("create()") SetAttributeNode setAttrNode,
@Cached("create()") SetClassAttributeNode setClassAttrNode,
@Cached("create()") GetDimAttributeNode getDimNode) {
return cdist(x, method, list, p, method, x.slowPathAccess(), getMethod(method), setAttrNode, setClassAttrNode, getDimNode);
}
private static boolean bothNonNAN(double a, double b) { private static boolean bothNonNAN(double a, double b) {
return !RRuntime.isNAorNaN(a) && !RRuntime.isNAorNaN(b); return !RRuntime.isNAorNaN(a) && !RRuntime.isNAorNaN(b);
} }
...@@ -96,7 +108,7 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -96,7 +108,7 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
return Method.values()[method - 1]; return Method.values()[method - 1];
} }
private void rdistance(ReadAccessor.Double xAccess, int nr, int nc, double[] d, boolean diag, Method method, double p) { private void rdistance(VectorAccess xAccess, RandomIterator xIter, int nr, int nc, double[] d, boolean diag, Method method, double p) {
int ij; /* can exceed 2^31 - 1, but Java can't handle that */ int ij; /* can exceed 2^31 - 1, but Java can't handle that */
// //
if (method == Method.MINKOWSKI) { if (method == Method.MINKOWSKI) {
...@@ -109,7 +121,7 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -109,7 +121,7 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
naCheck.enable(true); naCheck.enable(true);
for (int j = 0; j <= nr; j++) { for (int j = 0; j <= nr; j++) {
for (int i = j + dc; i < nr; i++) { for (int i = j + dc; i < nr; i++) {
double r = method.dist(xAccess, nr, nc, i, j, p); double r = method.dist(xAccess, xIter, nr, nc, i, j, p);
naCheck.check(r); naCheck.check(r);
d[ij++] = r; d[ij++] = r;
} }
...@@ -119,7 +131,7 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -119,7 +131,7 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
public enum Method { public enum Method {
EUCLIDEAN { EUCLIDEAN {
@Override @Override
public double dist(ReadAccessor.Double xAccess, int nr, int nc, final int i1in, final int i2in, double p) { public double dist(VectorAccess xAccess, RandomIterator xIter, int nr, int nc, final int i1in, final int i2in, double p) {
int i1 = i1in; int i1 = i1in;
int i2 = i2in; int i2 = i2in;
double dev; double dev;
...@@ -130,8 +142,8 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -130,8 +142,8 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
count = 0; count = 0;
dist = 0; dist = 0;
for (j = 0; j < nc; j++) { for (j = 0; j < nc; j++) {
if (bothNonNAN(xAccess.getDataAt(i1), xAccess.getDataAt(i2))) { if (bothNonNAN(xAccess.getDouble(xIter, i1), xAccess.getDouble(xIter, i2))) {
dev = (xAccess.getDataAt(i1) - xAccess.getDataAt(i2)); dev = (xAccess.getDouble(xIter, i1) - xAccess.getDouble(xIter, i2));
if (!RRuntime.isNAorNaN(dev)) { if (!RRuntime.isNAorNaN(dev)) {
dist += dev * dev; dist += dev * dev;
count++; count++;
...@@ -152,7 +164,7 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -152,7 +164,7 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
}, },
MAXIMUM { MAXIMUM {
@Override @Override
public double dist(ReadAccessor.Double xAccess, int nr, int nc, final int i1in, final int i2in, double p) { public double dist(VectorAccess xAccess, RandomIterator xIter, int nr, int nc, final int i1in, final int i2in, double p) {
int i1 = i1in; int i1 = i1in;
int i2 = i2in; int i2 = i2in;
double dev; double dev;
...@@ -163,8 +175,8 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -163,8 +175,8 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
count = 0; count = 0;
dist = -Double.MAX_VALUE; dist = -Double.MAX_VALUE;
for (j = 0; j < nc; j++) { for (j = 0; j < nc; j++) {
if (bothNonNAN(xAccess.getDataAt(i1), xAccess.getDataAt(i2))) { if (bothNonNAN(xAccess.getDouble(xIter, i1), xAccess.getDouble(xIter, i2))) {
dev = Math.abs(xAccess.getDataAt(i1) - xAccess.getDataAt(i2)); dev = Math.abs(xAccess.getDouble(xIter, i1) - xAccess.getDouble(xIter, i2));
if (!RRuntime.isNAorNaN(dev)) { if (!RRuntime.isNAorNaN(dev)) {
if (dev > dist) { if (dev > dist) {
dist = dev; dist = dev;
...@@ -184,7 +196,7 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -184,7 +196,7 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
}, },
MANHATTAN { MANHATTAN {
@Override @Override
public double dist(ReadAccessor.Double xAccess, int nr, int nc, final int i1in, final int i2in, double p) { public double dist(VectorAccess xAccess, RandomIterator xIter, int nr, int nc, final int i1in, final int i2in, double p) {
int i1 = i1in; int i1 = i1in;
int i2 = i2in; int i2 = i2in;
double dev; double dev;
...@@ -195,8 +207,8 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -195,8 +207,8 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
count = 0; count = 0;
dist = 0; dist = 0;
for (j = 0; j < nc; j++) { for (j = 0; j < nc; j++) {
if (bothNonNAN(xAccess.getDataAt(i1), xAccess.getDataAt(i2))) { if (bothNonNAN(xAccess.getDouble(xIter, i1), xAccess.getDouble(xIter, i2))) {
dev = Math.abs(xAccess.getDataAt(i1) - xAccess.getDataAt(i2)); dev = Math.abs(xAccess.getDouble(xIter, i1) - xAccess.getDouble(xIter, i2));
if (!RRuntime.isNAorNaN(dev)) { if (!RRuntime.isNAorNaN(dev)) {
dist += dev; dist += dev;
count++; count++;
...@@ -217,7 +229,7 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -217,7 +229,7 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
}, },
CANBERRA { CANBERRA {
@Override @Override
public double dist(ReadAccessor.Double xAccess, int nr, int nc, final int i1in, final int i2in, double p) { public double dist(VectorAccess xAccess, RandomIterator xIter, int nr, int nc, final int i1in, final int i2in, double p) {
int i1 = i1in; int i1 = i1in;
int i2 = i2in; int i2 = i2in;
double dev; double dev;
...@@ -230,9 +242,9 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -230,9 +242,9 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
count = 0; count = 0;
dist = 0; dist = 0;
for (j = 0; j < nc; j++) { for (j = 0; j < nc; j++) {
if (bothNonNAN(xAccess.getDataAt(i1), xAccess.getDataAt(i2))) { if (bothNonNAN(xAccess.getDouble(xIter, i1), xAccess.getDouble(xIter, i2))) {
sum = Math.abs(xAccess.getDataAt(i1) + xAccess.getDataAt(i2)); sum = Math.abs(xAccess.getDouble(xIter, i1) + xAccess.getDouble(xIter, i2));
diff = Math.abs(xAccess.getDataAt(i1) - xAccess.getDataAt(i2)); diff = Math.abs(xAccess.getDouble(xIter, i1) - xAccess.getDouble(xIter, i2));
if (sum > DBL_MIN || diff > DBL_MIN) { if (sum > DBL_MIN || diff > DBL_MIN) {
dev = diff / sum; dev = diff / sum;
if (!RRuntime.isNAorNaN(dev) || if (!RRuntime.isNAorNaN(dev) ||
...@@ -258,7 +270,7 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -258,7 +270,7 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
}, },
BINARY { BINARY {
@Override @Override
public double dist(ReadAccessor.Double xAccess, int nr, int nc, final int i1in, final int i2in, double p) { public double dist(VectorAccess xAccess, RandomIterator xIter, int nr, int nc, final int i1in, final int i2in, double p) {
int i1 = i1in; int i1 = i1in;
int i2 = i2in; int i2 = i2in;
int total; int total;
...@@ -271,13 +283,13 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -271,13 +283,13 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
dist = 0; dist = 0;
for (j = 0; j < nc; j++) { for (j = 0; j < nc; j++) {
if (bothNonNAN(xAccess.getDataAt(i1), xAccess.getDataAt(i2))) { if (bothNonNAN(xAccess.getDouble(xIter, i1), xAccess.getDouble(xIter, i2))) {
if (!bothFinite(xAccess.getDataAt(i1), xAccess.getDataAt(i2))) { if (!bothFinite(xAccess.getDouble(xIter, i1), xAccess.getDouble(xIter, i2))) {
RError.warning(RError.SHOW_CALLER2, RError.Message.GENERIC, "treating non-finite values as NA"); RError.warning(RError.SHOW_CALLER2, RError.Message.GENERIC, "treating non-finite values as NA");
} else { } else {
if (xAccess.getDataAt(i1) != 0. || xAccess.getDataAt(i2) != 0.) { if (xAccess.getDouble(xIter, i1) != 0. || xAccess.getDouble(xIter, i2) != 0.) {
count++; count++;
if (!(xAccess.getDataAt(i1) != 0. && xAccess.getDataAt(i2) != 0.)) { if (!(xAccess.getDouble(xIter, i1) != 0. && xAccess.getDouble(xIter, i2) != 0.)) {
dist++; dist++;
} }
} }
...@@ -300,7 +312,7 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -300,7 +312,7 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
}, },
MINKOWSKI { MINKOWSKI {
@Override @Override
public double dist(ReadAccessor.Double xAccess, int nr, int nc, final int i1in, final int i2in, double p) { public double dist(VectorAccess xAccess, RandomIterator xIter, int nr, int nc, final int i1in, final int i2in, double p) {
int i1 = i1in; int i1 = i1in;
int i2 = i2in; int i2 = i2in;
double dev; double dev;
...@@ -311,8 +323,8 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -311,8 +323,8 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
count = 0; count = 0;
dist = 0; dist = 0;
for (j = 0; j < nc; j++) { for (j = 0; j < nc; j++) {
if (bothNonNAN(xAccess.getDataAt(i1), xAccess.getDataAt(i2))) { if (bothNonNAN(xAccess.getDouble(xIter, i1), xAccess.getDouble(xIter, i2))) {
dev = (xAccess.getDataAt(i1) - xAccess.getDataAt(i2)); dev = (xAccess.getDouble(xIter, i1) - xAccess.getDouble(xIter, i2));
if (!RRuntime.isNAorNaN(dev)) { if (!RRuntime.isNAorNaN(dev)) {
dist += Math.pow(Math.abs(dev), p); dist += Math.pow(Math.abs(dev), p);
count++; count++;
...@@ -331,6 +343,6 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 { ...@@ -331,6 +343,6 @@ public abstract class Cdist extends RExternalBuiltinNode.Arg4 {
} }
}; };
public abstract double dist(ReadAccessor.Double xAccess, int nr, int nc, int i1, int i2, double p); public abstract double dist(VectorAccess xAccess, RandomIterator xIter, int nr, int nc, int i1, int i2, double p);
} }
} }
...@@ -20,7 +20,8 @@ import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode; ...@@ -20,7 +20,8 @@ import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode;
import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RIntVector; import com.oracle.truffle.r.runtime.data.RIntVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorReadAccess; import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator;
// translated from library/stats/src/hclust_utils.c // translated from library/stats/src/hclust_utils.c
...@@ -32,10 +33,10 @@ public abstract class Cutree extends RExternalBuiltinNode.Arg2 { ...@@ -32,10 +33,10 @@ public abstract class Cutree extends RExternalBuiltinNode.Arg2 {
casts.arg(1).mustNotBeMissing().mapIf(nullValue(), emptyIntegerVector()).asIntegerVector(); casts.arg(1).mustNotBeMissing().mapIf(nullValue(), emptyIntegerVector()).asIntegerVector();
} }
@Specialization @Specialization(guards = {"mergeAccess.supports(merge)", "whichAccess.supports(which)"})
protected RIntVector cutree(RAbstractIntVector merge, RAbstractIntVector which, protected RIntVector cutree(RAbstractIntVector merge, RAbstractIntVector which,
@Cached("create()") VectorReadAccess.Int mergeAccess, @Cached("merge.access()") VectorAccess mergeAccess,
@Cached("create()") VectorReadAccess.Int whichAccess, @Cached("which.access()") VectorAccess whichAccess,
@Cached("create()") GetDimAttributeNode getDimNode) { @Cached("create()") GetDimAttributeNode getDimNode) {
int whichLen = which.getLength(); int whichLen = which.getLength();
...@@ -59,92 +60,96 @@ public abstract class Cutree extends RExternalBuiltinNode.Arg2 { ...@@ -59,92 +60,96 @@ public abstract class Cutree extends RExternalBuiltinNode.Arg2 {
int[] z = new int[n]; int[] z = new int[n];
int[] iAns = new int[n * whichLen]; int[] iAns = new int[n * whichLen];
Object mergeStore = mergeAccess.getDataStore(merge); try (RandomIterator mergeIter = mergeAccess.randomAccess(merge); RandomIterator whichIter = whichAccess.randomAccess(which)) {
Object whichStore = whichAccess.getDataStore(which);
// for (k = 1; k <= n; k++) { // for (k = 1; k <= n; k++) {
for (k = 0; k < n; k++) { for (k = 0; k < n; k++) {
sing[k] = true; /* is k-th obs. still alone in cluster ? */ sing[k] = true; /* is k-th obs. still alone in cluster ? */
mNr[k] = 0; /* containing last merge-step number of k-th obs. */ mNr[k] = 0; /* containing last merge-step number of k-th obs. */
} }
for (k = 1; k <= n - 1; k++) { for (k = 1; k <= n - 1; k++) {
/* k-th merge, from n-k+1 to n-k atoms: (m1,m2) = merge[ k , ] */ /* k-th merge, from n-k+1 to n-k atoms: (m1,m2) = merge[ k , ] */
m1 = mergeAccess.getDataAt(merge, mergeStore, k - 1); m1 = mergeAccess.getInt(mergeIter, k - 1);
m2 = mergeAccess.getDataAt(merge, mergeStore, n - 1 + k - 1); m2 = mergeAccess.getInt(mergeIter, n - 1 + k - 1);
if (m1 < 0 && m2 < 0) { /* merging atoms [-m1] and [-m2] */ if (m1 < 0 && m2 < 0) { /* merging atoms [-m1] and [-m2] */
mNr[adj(-m1)] = mNr[adj(-m2)] = k; mNr[adj(-m1)] = mNr[adj(-m2)] = k;
sing[adj(-m1)] = sing[adj(-m2)] = false; sing[adj(-m1)] = sing[adj(-m2)] = false;
} else if (m1 < 0 || m2 < 0) { /* the other >= 0 */ } else if (m1 < 0 || m2 < 0) { /* the other >= 0 */
if (m1 < 0) { if (m1 < 0) {
j = -m1; j = -m1;
m1 = m2; m1 = m2;
} else { } else {
j = -m2; j = -m2;
}
/* merging atom j & cluster m1 */
for (l = 1; l <= n; l++) {
if (mNr[adj(l)] == m1) {
mNr[adj(l)] = k;
} }
} /* merging atom j & cluster m1 */
mNr[adj(j)] = k; for (l = 1; l <= n; l++) {
sing[adj(j)] = false; if (mNr[adj(l)] == m1) {
} else { /* both m1, m2 >= 0 */ mNr[adj(l)] = k;
for (l = 1; l <= n; l++) { }
if (mNr[adj(l)] == m1 || mNr[adj(l)] == m2) { }
mNr[adj(l)] = k; mNr[adj(j)] = k;
sing[adj(j)] = false;
} else { /* both m1, m2 >= 0 */
for (l = 1; l <= n; l++) {
if (mNr[adj(l)] == m1 || mNr[adj(l)] == m2) {
mNr[adj(l)] = k;
}
} }
} }
}
/* /*
* does this k-th merge belong to a desired group size which[j] ? if yes, find j (maybe * does this k-th merge belong to a desired group size which[j] ? if yes, find j
* multiple ones): * (maybe multiple ones):
*/ */
foundJ = false; foundJ = false;
for (j = 0; j < whichLen; j++) { for (j = 0; j < whichLen; j++) {
if (whichAccess.getDataAt(which, whichStore, j) == n - k) { if (whichAccess.getInt(whichIter, j) == n - k) {
if (!foundJ) { /* first match (and usually only one) */ if (!foundJ) { /* first match (and usually only one) */
foundJ = true; foundJ = true;
// for (l = 1; l <= n; l++) // for (l = 1; l <= n; l++)
for (l = 0; l < n; l++) { for (l = 0; l < n; l++) {
z[l] = 0; z[l] = 0;
} }
nclust = 0; nclust = 0;
mm = j * n; /* may want to copy this column of ans[] */ mm = j * n; /* may want to copy this column of ans[] */
for (l = 1, m1 = mm; l <= n; l++, m1++) { for (l = 1, m1 = mm; l <= n; l++, m1++) {
if (sing[adj(l)]) { if (sing[adj(l)]) {
iAns[m1] = ++nclust; iAns[m1] = ++nclust;
} else { } else {
if (z[adj(mNr[adj(l)])] == 0) { if (z[adj(mNr[adj(l)])] == 0) {
z[adj(mNr[adj(l)])] = ++nclust; z[adj(mNr[adj(l)])] = ++nclust;
}
iAns[m1] = z[adj(mNr[adj(l)])];
} }
iAns[m1] = z[adj(mNr[adj(l)])]; }
} else { /* found_j: another which[j] == n-k : copy column */
for (l = 1, m1 = j * n, m2 = mm; l <= n; l++, m1++, m2++) {
iAns[m1] = iAns[m2];
} }
} }
} else { /* found_j: another which[j] == n-k : copy column */ } /* if ( match ) */
for (l = 1, m1 = j * n, m2 = mm; l <= n; l++, m1++, m2++) { } /* for(j .. which[j] ) */
iAns[m1] = iAns[m2]; } /* for(k ..) {merge} */
}
/* Dealing with trivial case which[] = n : */
for (j = 0; j < whichLen; j++) {
if (whichAccess.getInt(whichIter, j) == n) {
for (l = 1, m1 = j * n; l <= n; l++, m1++) {
iAns[m1] = l;
} }
} /* if ( match ) */
} /* for(j .. which[j] ) */
} /* for(k ..) {merge} */
/* Dealing with trivial case which[] = n : */
for (j = 0; j < whichLen; j++) {
if (whichAccess.getDataAt(which, whichStore, j) == n) {
for (l = 1, m1 = j * n; l <= n; l++, m1++) {
iAns[m1] = l;
} }
} }
} }
RIntVector result = RDataFactory.createIntVector(iAns, RDataFactory.COMPLETE_VECTOR, new int[]{n, whichLen}); return RDataFactory.createIntVector(iAns, RDataFactory.COMPLETE_VECTOR, new int[]{n, whichLen});
return result; }
@Specialization(replaces = "cutree")
protected RIntVector cutreeGeneric(RAbstractIntVector merge, RAbstractIntVector which,
@Cached("create()") GetDimAttributeNode getDimNode) {
return cutree(merge, which, merge.slowPathAccess(), which.slowPathAccess(), getDimNode);
} }
private static int adj(int i) { private static int adj(int i) {
......
...@@ -10,56 +10,63 @@ ...@@ -10,56 +10,63 @@
*/ */
package com.oracle.truffle.r.library.stats; package com.oracle.truffle.r.library.stats;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.missingValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode;
import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode; import com.oracle.truffle.r.nodes.builtin.RExternalBuiltinNode;
import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.data.RDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.nodes.SetDataAt; import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorReadAccess; import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator;
import com.oracle.truffle.r.runtime.data.nodes.VectorReuse;
public abstract class DoubleCentre extends RExternalBuiltinNode.Arg1 { public abstract class DoubleCentre extends RExternalBuiltinNode.Arg1 {
static { static {
Casts casts = new Casts(DoubleCentre.class); Casts casts = new Casts(DoubleCentre.class);
casts.arg(0).mustBe(missingValue().not()).mustBe(nullValue().not(), RError.Message.MACRO_CAN_BE_APPLIED_TO, "REAL()", "numeric", "NULL").asDoubleVector(); casts.arg(0).mustNotBeNull(RError.Message.MACRO_CAN_BE_APPLIED_TO, "REAL()", "numeric", "NULL").mustNotBeMissing().asDoubleVector();
} }
@Specialization @Specialization(guards = {"aAccess.supports(a)", "reuse.supports(a)"})
protected RDoubleVector doubleCentre(RAbstractDoubleVector aVecAbs, protected RAbstractDoubleVector doubleCentre(RAbstractDoubleVector a,
@Cached("create()") VectorReadAccess.Double aAccess, @Cached("a.access()") VectorAccess aAccess,
@Cached("create()") SetDataAt.Double aSetter, @Cached("createNonShared(a)") VectorReuse reuse,
@Cached("create()") GetDimAttributeNode getDimNode) { @Cached("create()") GetDimAttributeNode getDimNode) {
RDoubleVector aVec = aVecAbs.materialize(); int n = getDimNode.nrows(a);
int n = getDimNode.nrows(aVec);
Object aStore = aAccess.getDataStore(aVec); try (RandomIterator aIter = aAccess.randomAccess(a)) {
for (int i = 0; i < n; i++) { RAbstractDoubleVector result = reuse.getResult(a);
double sum = 0; VectorAccess resultAccess = reuse.access(result);
for (int j = 0; j < n; j++) { try (RandomIterator resultIter = resultAccess.randomAccess(result)) {
sum += aAccess.getDataAt(aVec, aStore, i + j * n); for (int i = 0; i < n; i++) {
} double sum = 0;
sum /= n; for (int j = 0; j < n; j++) {
for (int j = 0; j < n; j++) { sum += aAccess.getDouble(aIter, i + j * n);
double val = aAccess.getDataAt(aVec, aStore, i + j * n); }
aSetter.setDataAt(aVec, aStore, i + j * n, val - sum); sum /= n;
} for (int j = 0; j < n; j++) {
} resultAccess.setDouble(resultIter, i + j * n, aAccess.getDouble(aIter, i + j * n) - sum);
for (int j = 0; j < n; j++) { }
double sum = 0; }
for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) {
sum += aAccess.getDataAt(aVec, aStore, i + j * n); double sum = 0;
} for (int i = 0; i < n; i++) {
sum /= n; sum += resultAccess.getDouble(aIter, i + j * n);
for (int i = 0; i < n; i++) { }
double val = aAccess.getDataAt(aVec, aStore, i + j * n); sum /= n;
aSetter.setDataAt(aVec, aStore, i + j * n, val - sum); for (int i = 0; i < n; i++) {
resultAccess.setDouble(resultIter, i + j * n, resultAccess.getDouble(aIter, i + j * n) - sum);
}
}
} }
return result;
} }
return aVec; }
@Specialization(replaces = "doubleCentre")
protected RAbstractDoubleVector doubleCentreGeneric(RAbstractDoubleVector a,
@Cached("createNonSharedGeneric()") VectorReuse reuse,
@Cached("create()") GetDimAttributeNode getDimNode) {
return doubleCentre(a, a.slowPathAccess(), reuse, getDimNode);
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment