Skip to content
Snippets Groups Projects
Commit 3b456f5e authored by Zbynek Slajchrt's avatar Zbynek Slajchrt
Browse files

Diag builtin refactored

parent fe1cc11e
Branches
No related tags found
No related merge requests found
......@@ -24,6 +24,7 @@ import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.nodes.unary.CastDoubleNode;
import com.oracle.truffle.r.runtime.RError;
import com.oracle.truffle.r.runtime.RError.Message;
import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RComplex;
import com.oracle.truffle.r.runtime.data.RDataFactory;
......@@ -37,7 +38,7 @@ public abstract class Diag extends RBuiltinNode {
@Override
protected void createCasts(CastBuilder casts) {
casts.arg("x").asVector();
casts.arg("x").mapIf(complexValue().not(), asDoubleVector());
casts.arg("nrow").asIntegerVector().findFirst().mustBe(notIntNA(), Message.INVALID_LARGE_NA_VALUE, "nrow").mustBe(gte0(), Message.INVALID_NEGATIVE_VALUE, "nrow");
......@@ -62,6 +63,17 @@ public abstract class Diag extends RBuiltinNode {
}
}
@Specialization
protected Object diag(double x, int nrow, int ncol) {
int mn = (nrow < ncol) ? nrow : ncol;
double[] data = new double[nrow * ncol];
for (int j = 0; j < mn; j++) {
data[j * (nrow + 1)] = x;
}
return RDataFactory.createDoubleVector(data, RRuntime.isNA(x), new int[]{nrow, ncol});
}
@Specialization
protected Object diag(RAbstractComplexVector x, int nrow, int ncol) {
int mn = checkX(x, nrow, ncol);
......@@ -77,12 +89,10 @@ public abstract class Diag extends RBuiltinNode {
return RDataFactory.createComplexVector(data, x.isComplete(), new int[]{nrow, ncol});
}
@Specialization(guards = "!isRAbstractComplexVector(x)")
protected Object diag(RAbstractVector x, int nrow, int ncol, //
@Cached("create()") CastDoubleNode cast) {
int mn = checkX(x, nrow, ncol);
@Specialization
protected Object diag(RAbstractDoubleVector source, int nrow, int ncol) {
int mn = checkX(source, nrow, ncol);
RAbstractDoubleVector source = (RAbstractDoubleVector) cast.execute(x);
double[] data = new double[nrow * ncol];
int nx = source.getLength();
for (int j = 0; j < mn; j++) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment