diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Diag.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Diag.java index ca1cda7f78336c92bba9dfc5fa216436a663e481..a3e832046227da221745e0023bed1b3cac7d97fe 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Diag.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Diag.java @@ -14,26 +14,26 @@ package com.oracle.truffle.r.nodes.builtin.base; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.asDoubleVector; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.complexValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue; -import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.gte0; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.logicalValue; import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.notIntNA; +import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.nullValue; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; -import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RBuiltin; -import com.oracle.truffle.r.runtime.data.RComplex; import com.oracle.truffle.r.runtime.data.RDataFactory; +import com.oracle.truffle.r.runtime.data.RDataFactory.VectorFactory; import com.oracle.truffle.r.runtime.data.RNull; -import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; @RBuiltin(name = "diag", kind = INTERNAL, parameterNames = {"x", "nrow", "ncol"}, behavior = PURE) public abstract class Diag extends RBuiltinNode.Arg3 { @@ -48,7 +48,6 @@ public abstract class Diag extends RBuiltinNode.Arg3 { private int checkX(RAbstractVector x, int nrow, int ncol) { int mn = (nrow < ncol) ? nrow : ncol; if (mn > 0 && x.getLength() == 0) { - CompilerDirectives.transferToInterpreter(); throw error(Message.POSITIVE_LENGTH, "x"); } return mn; @@ -56,8 +55,8 @@ public abstract class Diag extends RBuiltinNode.Arg3 { @Specialization protected Object diag(@SuppressWarnings("unused") RNull x, int nrow, int ncol) { - if (nrow == 0 && ncol == 0) { - return RDataFactory.createDoubleVector(new double[]{}, true, new int[]{0, 0}); + if (nrow == 0 || ncol == 0) { + return RDataFactory.createDoubleVector(new double[]{}, true, new int[]{nrow, ncol}); } else { throw error(Message.X_NUMERIC); } @@ -74,43 +73,31 @@ public abstract class Diag extends RBuiltinNode.Arg3 { return RDataFactory.createDoubleVector(data, !RRuntime.isNA(x), new int[]{nrow, ncol}); } - @Specialization - protected Object diag(RAbstractComplexVector x, int nrow, int ncol) { + @Specialization(guards = "xAccess.supports(x)") + protected RAbstractVector diagCached(RAbstractVector x, int nrow, int ncol, + @Cached("x.access()") VectorAccess xAccess, + @Cached("createNew(xAccess.getType())") VectorAccess resultAccess, + @Cached("create()") VectorFactory factory) { int mn = checkX(x, nrow, ncol); - double[] data = new double[nrow * ncol * 2]; - int nx = x.getLength(); - for (int j = 0; j < mn; j++) { - RComplex value = x.getDataAt(j % nx); - int index = j * (nrow + 1) * 2; - data[index] = value.getRealPart(); - data[index + 1] = value.getImaginaryPart(); - } - return RDataFactory.createComplexVector(data, x.isComplete(), new int[]{nrow, ncol}); - } - - @Specialization - protected Object diag(RAbstractDoubleVector source, int nrow, int ncol) { - int mn = checkX(source, nrow, ncol); - - double[] data = new double[nrow * ncol]; - int nx = source.getLength(); - for (int j = 0; j < mn; j++) { - data[j * (nrow + 1)] = source.getDataAt(j % nx); + try (SequentialIterator xIter = xAccess.access(x)) { + RAbstractVector result = factory.createUninitializedVector(xAccess.getType(), nrow * ncol, new int[]{nrow, ncol}, null, null); + try (RandomIterator resultIter = resultAccess.randomAccess(result)) { + int resultIndex = 0; + for (int j = 0; j < mn; j++) { + xAccess.nextWithWrap(xIter); + resultAccess.setFromSameType(resultIter, resultIndex, xAccess, xIter); + resultIndex += nrow + 1; + } + } + result.setComplete(x.isComplete()); + return result; } - return RDataFactory.createDoubleVector(data, source.isComplete(), new int[]{nrow, ncol}); } - @Specialization - protected Object diag(RAbstractLogicalVector source, int nrow, int ncol) { - int mn = checkX(source, nrow, ncol); - - byte[] data = new byte[nrow * ncol]; - int nx = source.getLength(); - for (int j = 0; j < mn; j++) { - data[j * (nrow + 1)] = source.getDataAt(j % nx); - } - return RDataFactory.createLogicalVector(data, source.isComplete(), new int[]{nrow, ncol}); + @Specialization(replaces = "diagCached") + protected RAbstractVector diagGeneric(RAbstractVector x, int nrow, int ncol, + @Cached("create()") VectorFactory factory) { + return diagCached(x, nrow, ncol, x.slowPathAccess(), VectorAccess.createSlowPathNew(x.getRType()), factory); } - }