Skip to content
Snippets Groups Projects
Commit 1c50eb8f authored by Lukas Stadler's avatar Lukas Stadler
Browse files

convert Diag to use VectorAccess

parent 5a75d737
No related branches found
No related tags found
No related merge requests found
...@@ -14,26 +14,26 @@ package com.oracle.truffle.r.nodes.builtin.base; ...@@ -14,26 +14,26 @@ package com.oracle.truffle.r.nodes.builtin.base;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.asDoubleVector; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.asDoubleVector;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.complexValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.complexValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gte0; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gte0;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.notIntNA; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.notIntNA;
import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue;
import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE;
import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL;
import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode;
import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.RError.Message;
import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.RRuntime;
import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.builtins.RBuiltin;
import com.oracle.truffle.r.runtime.data.RComplex;
import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RDataFactory;
import com.oracle.truffle.r.runtime.data.RDataFactory.VectorFactory;
import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RNull;
import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector;
import com.oracle.truffle.r.runtime.data.model.RAbstractVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator;
import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator;
@RBuiltin(name = "diag", kind = INTERNAL, parameterNames = {"x", "nrow", "ncol"}, behavior = PURE) @RBuiltin(name = "diag", kind = INTERNAL, parameterNames = {"x", "nrow", "ncol"}, behavior = PURE)
public abstract class Diag extends RBuiltinNode.Arg3 { public abstract class Diag extends RBuiltinNode.Arg3 {
...@@ -48,7 +48,6 @@ public abstract class Diag extends RBuiltinNode.Arg3 { ...@@ -48,7 +48,6 @@ public abstract class Diag extends RBuiltinNode.Arg3 {
private int checkX(RAbstractVector x, int nrow, int ncol) { private int checkX(RAbstractVector x, int nrow, int ncol) {
int mn = (nrow < ncol) ? nrow : ncol; int mn = (nrow < ncol) ? nrow : ncol;
if (mn > 0 && x.getLength() == 0) { if (mn > 0 && x.getLength() == 0) {
CompilerDirectives.transferToInterpreter();
throw error(Message.POSITIVE_LENGTH, "x"); throw error(Message.POSITIVE_LENGTH, "x");
} }
return mn; return mn;
...@@ -56,8 +55,8 @@ public abstract class Diag extends RBuiltinNode.Arg3 { ...@@ -56,8 +55,8 @@ public abstract class Diag extends RBuiltinNode.Arg3 {
@Specialization @Specialization
protected Object diag(@SuppressWarnings("unused") RNull x, int nrow, int ncol) { protected Object diag(@SuppressWarnings("unused") RNull x, int nrow, int ncol) {
if (nrow == 0 && ncol == 0) { if (nrow == 0 || ncol == 0) {
return RDataFactory.createDoubleVector(new double[]{}, true, new int[]{0, 0}); return RDataFactory.createDoubleVector(new double[]{}, true, new int[]{nrow, ncol});
} else { } else {
throw error(Message.X_NUMERIC); throw error(Message.X_NUMERIC);
} }
...@@ -74,43 +73,31 @@ public abstract class Diag extends RBuiltinNode.Arg3 { ...@@ -74,43 +73,31 @@ public abstract class Diag extends RBuiltinNode.Arg3 {
return RDataFactory.createDoubleVector(data, !RRuntime.isNA(x), new int[]{nrow, ncol}); return RDataFactory.createDoubleVector(data, !RRuntime.isNA(x), new int[]{nrow, ncol});
} }
@Specialization @Specialization(guards = "xAccess.supports(x)")
protected Object diag(RAbstractComplexVector x, int nrow, int ncol) { protected RAbstractVector diagCached(RAbstractVector x, int nrow, int ncol,
@Cached("x.access()") VectorAccess xAccess,
@Cached("createNew(xAccess.getType())") VectorAccess resultAccess,
@Cached("create()") VectorFactory factory) {
int mn = checkX(x, nrow, ncol); int mn = checkX(x, nrow, ncol);
double[] data = new double[nrow * ncol * 2]; try (SequentialIterator xIter = xAccess.access(x)) {
int nx = x.getLength(); RAbstractVector result = factory.createUninitializedVector(xAccess.getType(), nrow * ncol, new int[]{nrow, ncol}, null, null);
for (int j = 0; j < mn; j++) { try (RandomIterator resultIter = resultAccess.randomAccess(result)) {
RComplex value = x.getDataAt(j % nx); int resultIndex = 0;
int index = j * (nrow + 1) * 2; for (int j = 0; j < mn; j++) {
data[index] = value.getRealPart(); xAccess.nextWithWrap(xIter);
data[index + 1] = value.getImaginaryPart(); resultAccess.setFromSameType(resultIter, resultIndex, xAccess, xIter);
} resultIndex += nrow + 1;
return RDataFactory.createComplexVector(data, x.isComplete(), new int[]{nrow, ncol}); }
} }
result.setComplete(x.isComplete());
@Specialization return result;
protected Object diag(RAbstractDoubleVector source, int nrow, int ncol) {
int mn = checkX(source, nrow, ncol);
double[] data = new double[nrow * ncol];
int nx = source.getLength();
for (int j = 0; j < mn; j++) {
data[j * (nrow + 1)] = source.getDataAt(j % nx);
} }
return RDataFactory.createDoubleVector(data, source.isComplete(), new int[]{nrow, ncol});
} }
@Specialization @Specialization(replaces = "diagCached")
protected Object diag(RAbstractLogicalVector source, int nrow, int ncol) { protected RAbstractVector diagGeneric(RAbstractVector x, int nrow, int ncol,
int mn = checkX(source, nrow, ncol); @Cached("create()") VectorFactory factory) {
return diagCached(x, nrow, ncol, x.slowPathAccess(), VectorAccess.createSlowPathNew(x.getRType()), factory);
byte[] data = new byte[nrow * ncol];
int nx = source.getLength();
for (int j = 0; j < mn; j++) {
data[j * (nrow + 1)] = source.getDataAt(j % nx);
}
return RDataFactory.createLogicalVector(data, source.isComplete(), new int[]{nrow, ncol});
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment