Skip to content
Snippets Groups Projects
Commit 1c50eb8f authored by Lukas Stadler's avatar Lukas Stadler
Browse files

convert Diag to use VectorAccess

parent 5a75d737
No related branches found
No related tags found
No related merge requests found
......@@ -14,26 +14,26 @@ package com.oracle.truffle.r.nodes.builtin.base;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.asDoubleVector;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.complexValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gte0;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.notIntNA;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.runtime.RError.Message;
import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RComplex;
import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RDataFactory.VectorFactory;
import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
@RBuiltin(name = "diag", kind = INTERNAL, parameterNames = {"x", "nrow", "ncol"}, behavior = PURE)
public abstract class Diag extends RBuiltinNode.Arg3 {
......@@ -48,7 +48,6 @@ public abstract class Diag extends RBuiltinNode.Arg3 {
private int checkX(RAbstractVector x, int nrow, int ncol) {
int mn = (nrow < ncol) ? nrow : ncol;
if (mn > 0 && x.getLength() == 0) {
CompilerDirectives.transferToInterpreter();
throw error(Message.POSITIVE_LENGTH, "x");
}
return mn;
......@@ -56,8 +55,8 @@ public abstract class Diag extends RBuiltinNode.Arg3 {
@Specialization
protected Object diag(@SuppressWarnings("unused") RNull x, int nrow, int ncol) {
if (nrow == 0 && ncol == 0) {
return RDataFactory.createDoubleVector(new double[]{}, true, new int[]{0, 0});
if (nrow == 0 || ncol == 0) {
return RDataFactory.createDoubleVector(new double[]{}, true, new int[]{nrow, ncol});
} else {
throw error(Message.X_NUMERIC);
}
......@@ -74,43 +73,31 @@ public abstract class Diag extends RBuiltinNode.Arg3 {
return RDataFactory.createDoubleVector(data, !RRuntime.isNA(x), new int[]{nrow, ncol});
}
@Specialization
protected Object diag(RAbstractComplexVector x, int nrow, int ncol) {
@Specialization(guards = "xAccess.supports(x)")
protected RAbstractVector diagCached(RAbstractVector x, int nrow, int ncol,
@Cached("x.access()") VectorAccess xAccess,
@Cached("createNew(xAccess.getType())") VectorAccess resultAccess,
@Cached("create()") VectorFactory factory) {
int mn = checkX(x, nrow, ncol);
double[] data = new double[nrow * ncol * 2];
int nx = x.getLength();
for (int j = 0; j < mn; j++) {
RComplex value = x.getDataAt(j % nx);
int index = j * (nrow + 1) * 2;
data[index] = value.getRealPart();
data[index + 1] = value.getImaginaryPart();
}
return RDataFactory.createComplexVector(data, x.isComplete(), new int[]{nrow, ncol});
}
@Specialization
protected Object diag(RAbstractDoubleVector source, int nrow, int ncol) {
int mn = checkX(source, nrow, ncol);
double[] data = new double[nrow * ncol];
int nx = source.getLength();
for (int j = 0; j < mn; j++) {
data[j * (nrow + 1)] = source.getDataAt(j % nx);
try (SequentialIterator xIter = xAccess.access(x)) {
RAbstractVector result = factory.createUninitializedVector(xAccess.getType(), nrow * ncol, new int[]{nrow, ncol}, null, null);
try (RandomIterator resultIter = resultAccess.randomAccess(result)) {
int resultIndex = 0;
for (int j = 0; j < mn; j++) {
xAccess.nextWithWrap(xIter);
resultAccess.setFromSameType(resultIter, resultIndex, xAccess, xIter);
resultIndex += nrow + 1;
}
}
result.setComplete(x.isComplete());
return result;
}
return RDataFactory.createDoubleVector(data, source.isComplete(), new int[]{nrow, ncol});
}
@Specialization
protected Object diag(RAbstractLogicalVector source, int nrow, int ncol) {
int mn = checkX(source, nrow, ncol);
byte[] data = new byte[nrow * ncol];
int nx = source.getLength();
for (int j = 0; j < mn; j++) {
data[j * (nrow + 1)] = source.getDataAt(j % nx);
}
return RDataFactory.createLogicalVector(data, source.isComplete(), new int[]{nrow, ncol});
@Specialization(replaces = "diagCached")
protected RAbstractVector diagGeneric(RAbstractVector x, int nrow, int ncol,
@Cached("create()") VectorFactory factory) {
return diagCached(x, nrow, ncol, x.slowPathAccess(), VectorAccess.createSlowPathNew(x.getRType()), factory);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment