diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Array.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Array.java index 533a4b2caa36879ed71a8fe176522ab6b9c50ee8..f23c6fd39aecbd44a4aa4cfbec5c28731bc57317 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Array.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Array.java @@ -30,46 +30,27 @@ import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; +import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.runtime.RError; -import com.oracle.truffle.r.runtime.RRuntime; -import com.oracle.truffle.r.runtime.Utils; +import com.oracle.truffle.r.runtime.RType; import com.oracle.truffle.r.runtime.builtins.RBuiltin; -import com.oracle.truffle.r.runtime.data.RComplex; -import com.oracle.truffle.r.runtime.data.RComplexVector; -import com.oracle.truffle.r.runtime.data.RDataFactory; -import com.oracle.truffle.r.runtime.data.RDoubleVector; -import com.oracle.truffle.r.runtime.data.RIntVector; +import com.oracle.truffle.r.runtime.data.RDataFactory.VectorFactory; import com.oracle.truffle.r.runtime.data.RList; -import com.oracle.truffle.r.runtime.data.RLogicalVector; import com.oracle.truffle.r.runtime.data.RNull; -import com.oracle.truffle.r.runtime.data.RRawVector; -import com.oracle.truffle.r.runtime.data.RStringVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; import com.oracle.truffle.r.runtime.data.model.RAbstractContainer; -import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; +import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; -/** - * The {@code} .Internal part of the {@code array} function. The R code may alter the arguments - * before calling {@code .Internal}. - * - * <pre> - * array <- function(data = NA, dim = length(data), dimnames = NULL) { .Internal.array(data, dim, dimnames) } - * </pre> - * - * TODO complete. This is sufficient for the b25 benchmark use. - */ @RBuiltin(name = "array", kind = INTERNAL, parameterNames = {"data", "dim", "dimnames"}, behavior = PURE) public abstract class Array extends RBuiltinNode.Arg3 { @Child private UpdateDimNames updateDimNames; - private final ConditionProfile nonEmptyVectorProfile = ConditionProfile.createBinaryProfile(); // it's OK for the following method to update dimnames in-place as the container is "fresh" private void updateDimNames(RAbstractContainer container, Object o) { @@ -87,227 +68,76 @@ public abstract class Array extends RBuiltinNode.Arg3 { casts.arg("dimnames").defaultError(RError.Message.DIMNAMES_LIST).allowNull().mustBe(instanceOf(RList.class)); } - private int dimDataHelper(RAbstractIntVector dim, int[] dimData) { + @Specialization(guards = {"dataAccess.supports(data)", "dimAccess.supports(dim)"}) + protected RAbstractVector arrayCached(RAbstractVector data, RAbstractIntVector dim, Object dimNames, + @Cached("data.access()") VectorAccess dataAccess, + @Cached("dim.access()") VectorAccess dimAccess, + @Cached("createNew(dataAccess.getType())") VectorAccess resultAccess, + @Cached("createBinaryProfile()") ConditionProfile hasDimNames, + @Cached("createBinaryProfile()") ConditionProfile isEmpty, + @Cached("create()") VectorFactory factory) { + // extract dimensions and compute total length + int[] dimArray; int totalLength = 1; - int seenNegative = 0; - for (int i = 0; i < dim.getLength(); i++) { - dimData[i] = dim.getDataAt(i); - if (dimData[i] < 0) { - seenNegative++; + boolean negativeDims = false; + try (SequentialIterator dimIter = dimAccess.access(dim)) { + dimArray = new int[dimAccess.getLength(dimIter)]; + while (dimAccess.next(dimIter)) { + int dimValue = dimAccess.getInt(dimIter); + if (dimValue < 0) { + negativeDims = true; + } + totalLength *= dimValue; + dimArray[dimIter.getIndex()] = dimValue; } - totalLength *= dimData[i]; } - if (seenNegative == dim.getLength() && seenNegative != 0) { - throw error(RError.Message.DIMS_CONTAIN_NEGATIVE_VALUES); - } else if (seenNegative > 0) { + if (totalLength < 0) { throw error(RError.Message.NEGATIVE_LENGTH_VECTORS_NOT_ALLOWED); + } else if (negativeDims) { + throw error(RError.Message.DIMS_CONTAIN_NEGATIVE_VALUES); } - return totalLength; - } - - private RIntVector doArrayInt(RAbstractIntVector vec, RAbstractIntVector dim) { - int[] dimData = new int[dim.getLength()]; - int totalLength = dimDataHelper(dim, dimData); - int[] data = new int[totalLength]; - int vecLength = vec.getLength(); - if (nonEmptyVectorProfile.profile(totalLength > 0 && vecLength > 0)) { - for (int i = 0; i < totalLength; i++) { - data[i] = vec.getDataAt(i % vec.getLength()); - } - return RDataFactory.createIntVector(data, vec.isComplete(), dimData); - } else { - for (int i = 0; i < totalLength; i++) { - data[i] = RRuntime.INT_NA; - } - return RDataFactory.createIntVector(data, RDataFactory.INCOMPLETE_VECTOR, dimData); - } - } - - @Specialization - protected RIntVector doArrayNoDimNames(RAbstractIntVector vec, RAbstractIntVector dim, @SuppressWarnings("unused") RNull dimnames) { - return doArrayInt(vec, dim); - } - - @Specialization - protected RIntVector doArray(RAbstractIntVector vec, RAbstractIntVector dim, RList dimnames) { - RIntVector ret = doArrayInt(vec, dim); - updateDimNames(ret, dimnames); - return ret; - } - - private RDoubleVector doArrayDouble(RAbstractDoubleVector vec, RAbstractIntVector dim) { - int[] dimData = new int[dim.getLength()]; - int totalLength = dimDataHelper(dim, dimData); - double[] data = new double[totalLength]; - int vecLength = vec.getLength(); - if (totalLength > 0 && vecLength > 0) { - for (int i = 0; i < totalLength; i++) { - data[i] = vec.getDataAt(i % vec.getLength()); - } - return RDataFactory.createDoubleVector(data, vec.isComplete(), dimData); - } else { - for (int i = 0; i < totalLength; i++) { - data[i] = RRuntime.DOUBLE_NA; - } - return RDataFactory.createDoubleVector(data, RDataFactory.INCOMPLETE_VECTOR, dimData); - } - } - - @Specialization - protected RDoubleVector doArrayNoDimNames(RAbstractDoubleVector vec, RAbstractIntVector dim, @SuppressWarnings("unused") RNull dimnames) { - return doArrayDouble(vec, dim); - } - - @Specialization - protected RDoubleVector doArray(RAbstractDoubleVector vec, RAbstractIntVector dim, RList dimnames) { - RDoubleVector ret = doArrayDouble(vec, dim); - updateDimNames(ret, dimnames); - return ret; - } - - private RLogicalVector doArrayLogical(RAbstractLogicalVector vec, RAbstractIntVector dim) { - int[] dimData = new int[dim.getLength()]; - int totalLength = dimDataHelper(dim, dimData); - byte[] data = new byte[totalLength]; - int vecLength = vec.getLength(); - if (totalLength > 0 && vecLength > 0) { - for (int i = 0; i < totalLength; i++) { - data[i] = vec.getDataAt(i % vec.getLength()); - } - return RDataFactory.createLogicalVector(data, vec.isComplete(), dimData); - } else { - for (int i = 0; i < totalLength; i++) { - data[i] = RRuntime.LOGICAL_NA; - } - return RDataFactory.createLogicalVector(data, RDataFactory.INCOMPLETE_VECTOR, dimData); - } - } - - @Specialization - protected RLogicalVector doArrayNoDimNames(RAbstractLogicalVector vec, RAbstractIntVector dim, @SuppressWarnings("unused") RNull dimnames) { - return doArrayLogical(vec, dim); - } - - @Specialization - protected RLogicalVector doArray(RAbstractLogicalVector vec, RAbstractIntVector dim, RList dimnames) { - RLogicalVector ret = doArrayLogical(vec, dim); - updateDimNames(ret, dimnames); - return ret; - } - - private RStringVector doArrayString(RAbstractStringVector vec, RAbstractIntVector dim) { - int[] dimData = new int[dim.getLength()]; - int totalLength = dimDataHelper(dim, dimData); - String[] data = new String[totalLength]; - int vecLength = vec.getLength(); - if (totalLength > 0 && vecLength > 0) { - for (int i = 0; i < totalLength; i++) { - data[i] = vec.getDataAt(i % vec.getLength()); - } - return RDataFactory.createStringVector(data, vec.isComplete(), dimData); - } else { - String empty = Utils.intern(""); - for (int i = 0; i < totalLength; i++) { - data[i] = empty; - } - return RDataFactory.createStringVector(data, RDataFactory.COMPLETE_VECTOR, dimData); - } - } - - @Specialization - protected RStringVector doArrayNoDimNames(RAbstractStringVector vec, RAbstractIntVector dim, @SuppressWarnings("unused") RNull dimnames) { - return doArrayString(vec, dim); - } - @Specialization - protected RStringVector doArray(RAbstractStringVector vec, RAbstractIntVector dim, RList dimnames) { - RStringVector ret = doArrayString(vec, dim); - updateDimNames(ret, dimnames); - return ret; - } - - private RComplexVector doArrayComplex(RAbstractComplexVector vec, RAbstractIntVector dim) { - int[] dimData = new int[dim.getLength()]; - int totalLength = dimDataHelper(dim, dimData); - double[] data = new double[totalLength << 1]; - int ind = 0; - int vecLength = vec.getLength(); - if (totalLength > 0 && vecLength > 0) { - for (int i = 0; i < totalLength; i++) { - RComplex d = vec.getDataAt(i % vec.getLength()); - data[ind++] = d.getRealPart(); - data[ind++] = d.getImaginaryPart(); - } - return RDataFactory.createComplexVector(data, vec.isComplete(), dimData); - } else { - for (int i = 0; i < totalLength; i++) { - data[ind++] = RRuntime.COMPLEX_NA_REAL_PART; - data[ind++] = RRuntime.COMPLEX_NA_IMAGINARY_PART; + RAbstractVector result = factory.createUninitializedVector(dataAccess.getType(), totalLength, dimArray, null, null); + + try (SequentialIterator resultIter = resultAccess.access(result); SequentialIterator dataIter = dataAccess.access(data)) { + if (isEmpty.profile(dataAccess.getLength(dataIter) == 0)) { + if (dataAccess.getType() == RType.Character) { + // character vectors are initialized with "" instead of NA + while (resultAccess.next(resultIter)) { + resultAccess.setString(resultIter, ""); + } + result.setComplete(true); + } else { + while (resultAccess.next(resultIter)) { + resultAccess.setNA(resultIter); + } + result.setComplete(false); + } + } else { + while (resultAccess.next(resultIter)) { + dataAccess.nextWithWrap(dataIter); + resultAccess.setFromSameType(resultIter, dataAccess, dataIter); + } + result.setComplete(!dataAccess.na.isEnabled()); } - return RDataFactory.createComplexVector(data, RDataFactory.INCOMPLETE_VECTOR, dimData); } - } - - @Specialization - protected RComplexVector doArrayNoDimNames(RAbstractComplexVector vec, RAbstractIntVector dim, @SuppressWarnings("unused") RNull dimnames) { - return doArrayComplex(vec, dim); - } - - @Specialization - protected RComplexVector doArray(RAbstractComplexVector vec, RAbstractIntVector dim, RList dimnames) { - RComplexVector ret = doArrayComplex(vec, dim); - updateDimNames(ret, dimnames); - return ret; - } - private RRawVector doArrayRaw(RAbstractRawVector vec, RAbstractIntVector dim) { - int[] dimData = new int[dim.getLength()]; - int totalLength = dimDataHelper(dim, dimData); - byte[] data = new byte[totalLength]; - int vecLength = vec.getLength(); - if (totalLength > 0 && vecLength > 0) { - for (int i = 0; i < totalLength; i++) { - data[i] = vec.getRawDataAt(i % vec.getLength()); - } + // dimensions are set as a separate step so they are checked for validity + if (hasDimNames.profile(dimNames instanceof RList)) { + updateDimNames(result, dimNames); } else { - for (int i = 0; i < totalLength; i++) { - data[i] = 0; - } - } - return RDataFactory.createRawVector(data, dimData); - } - - @Specialization - protected RRawVector doArrayNoDimNames(RAbstractRawVector vec, RAbstractIntVector dim, @SuppressWarnings("unused") RNull dimnames) { - return doArrayRaw(vec, dim); - } - - @Specialization - protected RRawVector doArray(RAbstractRawVector vec, RAbstractIntVector dim, RList dimnames) { - RRawVector ret = doArrayRaw(vec, dim); - updateDimNames(ret, dimnames); - return ret; - } - - private RList doArrayList(RList vec, RAbstractIntVector dim) { - int[] dimData = new int[dim.getLength()]; - int totalLength = dimDataHelper(dim, dimData); - Object[] data = new Object[totalLength]; - for (int i = 0; i < totalLength; i++) { - data[i] = vec.getDataAt(i % vec.getLength()); + assert dimNames instanceof RNull; } - return RDataFactory.createList(data, dimData); - } - - @Specialization - protected RList doArrayNoDimeNames(RList vec, RAbstractIntVector dim, @SuppressWarnings("unused") RNull dimnames) { - return doArrayList(vec, dim); + return result; } - @Specialization - protected RList doArray(RList vec, RAbstractIntVector dim, RList dimnames) { - RList ret = doArrayList(vec, dim); - updateDimNames(ret, dimnames); - return ret; + @Specialization(replaces = "arrayCached") + @TruffleBoundary + protected RAbstractVector arrayGeneric(RAbstractVector data, RAbstractIntVector dim, Object dimNames, + @Cached("createBinaryProfile()") ConditionProfile hasDimNames, + @Cached("createBinaryProfile()") ConditionProfile isEmpty, + @Cached("create()") VectorFactory factory) { + VectorAccess dataAccess = data.slowPathAccess(); + return arrayCached(data, dim, dimNames, dataAccess, dim.slowPathAccess(), VectorAccess.createSlowPathNew(dataAccess.getType()), hasDimNames, isEmpty, factory); } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IsNA.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IsNA.java index 05cbcbc30395a817ec2ee19a832c87018dced4ef..37d8da2fa8707921e44f38ef3b1093974fe2838c 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IsNA.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IsNA.java @@ -27,46 +27,41 @@ import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.PRIMITIVE; import com.oracle.truffle.api.CompilerDirectives; -import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; +import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.ImportStatic; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.interop.TruffleObject; -import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimNamesAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; -import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimNamesAttributeNode; +import com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.runtime.RError; +import com.oracle.truffle.r.runtime.RInternalError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RComplex; -import com.oracle.truffle.r.runtime.data.RComplexVector; import com.oracle.truffle.r.runtime.data.RDataFactory; -import com.oracle.truffle.r.runtime.data.RList; +import com.oracle.truffle.r.runtime.data.RDataFactory.VectorFactory; import com.oracle.truffle.r.runtime.data.RLogicalVector; import com.oracle.truffle.r.runtime.data.RNull; import com.oracle.truffle.r.runtime.data.RRaw; -import com.oracle.truffle.r.runtime.data.RRawVector; -import com.oracle.truffle.r.runtime.data.RStringVector; -import com.oracle.truffle.r.runtime.data.RTypedValue; -import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; @ImportStatic(RRuntime.class) @RBuiltin(name = "is.na", kind = PRIMITIVE, parameterNames = {"x"}, dispatch = INTERNAL_GENERIC, behavior = PURE) public abstract class IsNA extends RBuiltinNode.Arg1 { @Child private IsNA recursiveIsNA; - @Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create(); + + @Child private VectorFactory factory = VectorFactory.create(); @Child private GetDimAttributeNode getDimsNode = GetDimAttributeNode.create(); - @Child private SetDimNamesAttributeNode setDimNamesNode = SetDimNamesAttributeNode.create(); + @Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create(); @Child private GetDimNamesAttributeNode getDimNamesNode = GetDimNamesAttributeNode.create(); - private final ConditionProfile nullDimNamesProfile = ConditionProfile.createBinaryProfile(); - static { Casts.noCasts(IsNA.class); } @@ -86,93 +81,21 @@ public abstract class IsNA extends RBuiltinNode.Arg1 { return RRuntime.asLogical(RRuntime.isNA(value)); } - @Specialization - protected RLogicalVector isNA(RAbstractIntVector vector) { - byte[] resultVector = new byte[vector.getLength()]; - for (int i = 0; i < vector.getLength(); i++) { - resultVector[i] = RRuntime.asLogical(RRuntime.isNA(vector.getDataAt(i))); - } - return createResult(resultVector, vector); - } - @Specialization protected byte isNA(double value) { return RRuntime.asLogical(RRuntime.isNAorNaN(value)); } - @Specialization - protected RLogicalVector isNA(RAbstractDoubleVector vector) { - byte[] resultVector = new byte[vector.getLength()]; - for (int i = 0; i < vector.getLength(); i++) { - resultVector[i] = RRuntime.asLogical(RRuntime.isNAorNaN(vector.getDataAt(i))); - } - return createResult(resultVector, vector); - } - - @Specialization - protected RLogicalVector isNA(RComplexVector vector) { - byte[] resultVector = new byte[vector.getLength()]; - for (int i = 0; i < vector.getLength(); i++) { - RComplex complex = vector.getDataAt(i); - resultVector[i] = RRuntime.asLogical(RRuntime.isNA(complex)); - } - return createResult(resultVector, vector); - } - @Specialization protected byte isNA(String value) { return RRuntime.asLogical(RRuntime.isNA(value)); } - @Specialization - protected RLogicalVector isNA(RStringVector vector) { - byte[] resultVector = new byte[vector.getLength()]; - for (int i = 0; i < vector.getLength(); i++) { - resultVector[i] = RRuntime.asLogical(RRuntime.isNA(vector.getDataAt(i))); - } - return createResult(resultVector, vector); - } - - @Specialization - protected RLogicalVector isNA(RList list) { - byte[] resultVector = new byte[list.getLength()]; - for (int i = 0; i < list.getLength(); i++) { - Object result = isNARecursive(list.getDataAt(i)); - byte isNAResult; - if (result instanceof Byte) { - isNAResult = (Byte) result; - } else if (result instanceof RLogicalVector) { - RLogicalVector vector = (RLogicalVector) result; - // result is false unless that element is a length-one atomic vector - // and the single element of that vector is regarded as NA - isNAResult = (vector.getLength() == 1) ? vector.getDataAt(0) : RRuntime.LOGICAL_FALSE; - } else { - throw fail("unhandled return type in isNA(list)"); - } - resultVector[i] = isNAResult; - } - return RDataFactory.createLogicalVector(resultVector, RDataFactory.COMPLETE_VECTOR); - } - - @TruffleBoundary - private static UnsupportedOperationException fail(String message) { - throw new UnsupportedOperationException(message); - } - @Specialization protected byte isNA(byte value) { return RRuntime.asLogical(RRuntime.isNA(value)); } - @Specialization - protected RLogicalVector isNA(RLogicalVector vector) { - byte[] resultVector = new byte[vector.getLength()]; - for (int i = 0; i < vector.getLength(); i++) { - resultVector[i] = (RRuntime.isNA(vector.getDataAt(i)) ? RRuntime.LOGICAL_TRUE : RRuntime.LOGICAL_FALSE); - } - return createResult(resultVector, vector); - } - @Specialization protected byte isNA(RComplex value) { return RRuntime.asLogical(RRuntime.isNA(value)); @@ -183,19 +106,62 @@ public abstract class IsNA extends RBuiltinNode.Arg1 { return RRuntime.LOGICAL_FALSE; } - @Specialization - protected RLogicalVector isNA(RRawVector vector) { - byte[] resultVector = new byte[vector.getLength()]; - for (int i = 0; i < vector.getLength(); i++) { - resultVector[i] = RRuntime.LOGICAL_FALSE; + private RLogicalVector isNAVector(RAbstractVector vector, VectorAccess access) { + try (SequentialIterator iter = access.access(vector)) { + byte[] data = new byte[access.getLength(iter)]; + while (access.next(iter)) { + boolean isNA; + switch (access.getType()) { + case Double: + isNA = access.na.checkNAorNaN(access.getDouble(iter)); + break; + case Character: + case Complex: + case Integer: + case Logical: + isNA = access.isNA(iter); + break; + case Raw: + isNA = false; + break; + case List: + Object result = isNARecursive(access.getListElement(iter)); + if (result instanceof Byte) { + isNA = ((byte) result) == RRuntime.LOGICAL_TRUE; + } else if (result instanceof RLogicalVector) { + RLogicalVector recVector = (RLogicalVector) result; + // result is false unless that element is a length-one atomic vector + // and the single element of that vector is regarded as NA + isNA = (recVector.getLength() == 1) ? recVector.getDataAt(0) == RRuntime.LOGICAL_TRUE : false; + } else { + throw RInternalError.shouldNotReachHere("unhandled return type in isNA(list)"); + } + break; + default: + throw RInternalError.shouldNotReachHere(); + + } + data[iter.getIndex()] = RRuntime.asLogical(isNA); + } + return factory.createLogicalVector(data, RDataFactory.COMPLETE_VECTOR, getDimsNode.getDimensions(vector), getNamesNode.getNames(vector), getDimNamesNode.getDimNames(vector)); } - return createResult(resultVector, vector); + } + + @Specialization(guards = "access.supports(vector)") + protected RLogicalVector isNACached(RAbstractVector vector, + @Cached("vector.access()") VectorAccess access) { + return isNAVector(vector, access); + } + + @Specialization(replaces = "isNACached") + protected RLogicalVector isNAGeneric(RAbstractVector vector) { + return isNAVector(vector, vector.slowPathAccess()); } @Specialization protected RLogicalVector isNA(RNull value) { warning(RError.Message.IS_NA_TO_NON_VECTOR, value.getRType().getName()); - return RDataFactory.createEmptyLogicalVector(); + return factory.createEmptyLogicalVector(); } @Specialization(guards = "isForeignObject(obj)") @@ -203,20 +169,9 @@ public abstract class IsNA extends RBuiltinNode.Arg1 { return RRuntime.LOGICAL_FALSE; } - // Note: all the primitive values have specialization, so we can only get RTypedValue in - // fallback @Fallback protected byte isNA(Object value) { - warning(RError.Message.IS_NA_TO_NON_VECTOR, value instanceof RTypedValue ? ((RTypedValue) value).getRType().getName() : value); + warning(RError.Message.IS_NA_TO_NON_VECTOR, Predef.typeName().apply(value)); return RRuntime.LOGICAL_FALSE; } - - private RLogicalVector createResult(byte[] data, RAbstractVector originalVector) { - RLogicalVector result = RDataFactory.createLogicalVector(data, RDataFactory.COMPLETE_VECTOR, getDimsNode.getDimensions(originalVector), getNamesNode.getNames(originalVector)); - RList dimNames = getDimNamesNode.getDimNames(originalVector); - if (nullDimNamesProfile.profile(dimNames != null)) { - setDimNamesNode.setDimNames(result, dimNames); - } - return result; - } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RepeatInternal.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RepeatInternal.java index 9ece548f64e888b27287187d3d912401c354872f..0fade482ee9484506923c2d4d37feac2bc06bd62 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RepeatInternal.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/RepeatInternal.java @@ -28,29 +28,19 @@ import static com.oracle.truffle.r.nodes.builtin.CastBuilder.Predef.typeName; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; -import java.util.function.IntFunction; - +import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary; +import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.builtins.RBuiltin; -import com.oracle.truffle.r.runtime.data.RComplex; -import com.oracle.truffle.r.runtime.data.RComplexVector; -import com.oracle.truffle.r.runtime.data.RDataFactory; -import com.oracle.truffle.r.runtime.data.RDoubleVector; -import com.oracle.truffle.r.runtime.data.RIntVector; -import com.oracle.truffle.r.runtime.data.RList; -import com.oracle.truffle.r.runtime.data.RLogicalVector; -import com.oracle.truffle.r.runtime.data.RRawVector; -import com.oracle.truffle.r.runtime.data.RStringVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; +import com.oracle.truffle.r.runtime.data.RDataFactory.VectorFactory; +import com.oracle.truffle.r.runtime.data.RVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; @RBuiltin(name = "rep.int", kind = INTERNAL, parameterNames = {"x", "times"}, behavior = PURE) public abstract class RepeatInternal extends RBuiltinNode.Arg2 { @@ -64,136 +54,70 @@ public abstract class RepeatInternal extends RBuiltinNode.Arg2 { RError.Message.INVALID_VALUE, "times"); } - @FunctionalInterface - private interface ArrayUpdateFunction<ValueT, ArrayT> { - void update(ArrayT array, int pos, ValueT value, int index); - } + private RAbstractVector performRep(RAbstractVector value, RAbstractIntVector times, VectorFactory factory, VectorAccess valueAccess, VectorAccess timesAccess, VectorAccess resultAccess) { + try (SequentialIterator valueIter = valueAccess.access(value); SequentialIterator timesIter = timesAccess.access(times)) { + int valueLength = valueAccess.getLength(valueIter); + int timesLength = timesAccess.getLength(timesIter); - @FunctionalInterface - private interface CreateResultFunction<ResultT, ArrayT> { - ResultT create(ArrayT array, boolean complete); - } - - private <ValueT extends RAbstractVector, ResultT extends ValueT, ArrayT> ResultT repInt(ValueT value, RAbstractIntVector times, IntFunction<ArrayT> arrayConstructor, - ArrayUpdateFunction<ValueT, ArrayT> arrayUpdate, CreateResultFunction<ResultT, ArrayT> createResult) { - ArrayT result; - int timesLength = times.getLength(); - int valueLength = value.getLength(); - if (timesOneProfile.profile(timesLength == 1)) { - int timesValue = times.getDataAt(0); - if (timesValue < 0) { - throw error(RError.Message.INVALID_VALUE, "times"); - } - int count = timesValue * valueLength; - result = arrayConstructor.apply(count); - int pos = 0; - for (int i = 0; i < timesValue; i++) { - for (int j = 0; j < valueLength; j++) { - arrayUpdate.update(result, pos++, value, j); - } - } - } else if (timesLength == valueLength) { - int count = 0; - for (int i = 0; i < timesLength; i++) { - int data = times.getDataAt(i); - if (data < 0) { + RVector<?> result; + if (timesOneProfile.profile(timesLength == 1)) { + timesAccess.next(timesIter); + int timesValue = timesAccess.getInt(timesIter); + if (timesValue < 0) { throw error(RError.Message.INVALID_VALUE, "times"); } - count += data; - } - result = arrayConstructor.apply(count); - int pos = 0; - for (int i = 0; i < valueLength; i++) { - int num = times.getDataAt(i); - for (int j = 0; j < num; j++) { - arrayUpdate.update(result, pos++, value, i); - } - } - } else { - throw error(RError.Message.INVALID_VALUE, "times"); - } - return createResult.create(result, value.isComplete()); - } - - @Specialization - protected RDoubleVector repInt(RAbstractDoubleVector value, RAbstractIntVector times) { - return repInt(value, times, double[]::new, (array, pos, val, index) -> array[pos] = val.getDataAt(index), RDataFactory::createDoubleVector); - } - - @Specialization - protected RIntVector repInt(RAbstractIntVector value, RAbstractIntVector times) { - return repInt(value, times, int[]::new, (array, pos, val, index) -> array[pos] = val.getDataAt(index), RDataFactory::createIntVector); - } - - @Specialization - protected RLogicalVector repInt(RAbstractLogicalVector value, RAbstractIntVector times) { - return repInt(value, times, byte[]::new, (array, pos, val, index) -> array[pos] = val.getDataAt(index), RDataFactory::createLogicalVector); - } - - @Specialization - protected RStringVector repInt(RAbstractStringVector value, RAbstractIntVector times) { - return repInt(value, times, String[]::new, (array, pos, val, index) -> array[pos] = val.getDataAt(index), RDataFactory::createStringVector); - } - - @Specialization - protected RRawVector repInt(RAbstractRawVector value, RAbstractIntVector times) { - return repInt(value, times, byte[]::new, (array, pos, val, index) -> array[pos] = val.getRawDataAt(index), (array, complete) -> RDataFactory.createRawVector(array)); - } - - @Specialization - protected RComplexVector repComplex(RAbstractComplexVector value, RAbstractIntVector times) { - int timesLength = times.getLength(); - int valueLength = value.getLength(); - double[] resultArray; - if (timesOneProfile.profile(timesLength == 1)) { - int timesValue = times.getDataAt(0); - if (timesValue < 0) { - throw error(RError.Message.INVALID_VALUE, "times"); - } - resultArray = new double[(timesValue * valueLength) << 1]; - int pos = 0; - for (int i = 0; i < timesValue; i++) { - for (int j = 0; j < valueLength; j++) { - RComplex complex = value.getDataAt(j); - resultArray[pos++] = complex.getRealPart(); - resultArray[pos++] = complex.getImaginaryPart(); + result = factory.createVector(valueAccess.getType(), timesValue * valueLength, false); + try (SequentialIterator resultIter = resultAccess.access(result)) { + for (int i = 0; i < timesValue; i++) { + while (valueAccess.next(valueIter)) { + resultAccess.next(resultIter); + resultAccess.setFromSameType(resultIter, valueAccess, valueIter); + } + valueAccess.reset(valueIter); + } } - } - } else if (timesLength == valueLength) { - int count = 0; - for (int i = 0; i < timesLength; i++) { - int data = times.getDataAt(i); - if (data < 0) { - throw error(RError.Message.INVALID_VALUE, "times"); + } else if (timesLength == valueLength) { + int count = 0; + while (timesAccess.next(timesIter)) { + int num = timesAccess.getInt(timesIter); + if (num < 0) { + throw error(RError.Message.INVALID_VALUE, "times"); + } + count += num; } - count += data; - } - resultArray = new double[count << 1]; - int pos = 0; - for (int i = 0; i < valueLength; i++) { - int num = times.getDataAt(i); - RComplex complex = value.getDataAt(i); - for (int j = 0; j < num; j++) { - resultArray[pos++] = complex.getRealPart(); - resultArray[pos++] = complex.getImaginaryPart(); + result = factory.createVector(valueAccess.getType(), count, false); + + timesAccess.reset(timesIter); + try (SequentialIterator resultIter = resultAccess.access(result)) { + while (timesAccess.next(timesIter) && valueAccess.next(valueIter)) { + int num = timesAccess.getInt(timesIter); + for (int i = 0; i < num; i++) { + resultAccess.next(resultIter); + resultAccess.setFromSameType(resultIter, valueAccess, valueIter); + } + } } + } else { + throw error(RError.Message.INVALID_VALUE, "times"); } - } else { - throw error(RError.Message.INVALID_VALUE, "times"); + result.setComplete(!valueAccess.na.isEnabled()); + return result; } - return RDataFactory.createComplexVector(resultArray, value.isComplete()); } - @Specialization - protected RList repList(RList value, int times) { - int oldLength = value.getLength(); - int length = value.getLength() * times; - Object[] array = new Object[length]; - for (int i = 0; i < times; i++) { - for (int j = 0; j < oldLength; j++) { - array[i * oldLength + j] = value.getDataAt(j); - } - } - return RDataFactory.createList(array); + @Specialization(guards = {"valueAccess.supports(value)", "timesAccess.supports(times)"}) + protected RAbstractVector repCached(RAbstractVector value, RAbstractIntVector times, + @Cached("create()") VectorFactory factory, + @Cached("value.access()") VectorAccess valueAccess, + @Cached("times.access()") VectorAccess timesAccess, + @Cached("createNew(value.getRType())") VectorAccess resultAccess) { + return performRep(value, times, factory, valueAccess, timesAccess, resultAccess); + } + + @Specialization(replaces = "repCached") + @TruffleBoundary + protected RAbstractVector repGeneric(RAbstractVector value, RAbstractIntVector times, + @Cached("create()") VectorFactory factory) { + return performRep(value, times, factory, value.slowPathAccess(), times.slowPathAccess(), VectorAccess.createSlowPathNew(value.getRType())); } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Split.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Split.java index c760f844bb6b9edefe7a99821f79213d08e2009f..3a288c89a58e532a5082356475c88742234dad6e 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Split.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Split.java @@ -27,24 +27,23 @@ import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; import java.util.Arrays; +import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; +import com.oracle.truffle.r.nodes.builtin.base.SplitNodeGen.GetSplitNamesNodeGen; import com.oracle.truffle.r.nodes.helpers.RFactorNodes; -import com.oracle.truffle.r.runtime.Utils; +import com.oracle.truffle.r.runtime.RError.Message; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RList; -import com.oracle.truffle.r.runtime.data.RRawVector; import com.oracle.truffle.r.runtime.data.RStringVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractListVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; -import com.oracle.truffle.r.runtime.data.nodes.VectorReadAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.SequentialIterator; +import com.oracle.truffle.r.runtime.nodes.RBaseNode; /** * The {@code split} internal. Internal version of 'split' is invoked from 'split.default' function @@ -60,11 +59,7 @@ import com.oracle.truffle.r.runtime.data.nodes.VectorReadAccess; public abstract class Split extends RBuiltinNode.Arg2 { @Child private RFactorNodes.GetLevels getLevelNode = new RFactorNodes.GetLevels(); - @Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create(); - - @SuppressWarnings("unused") private final ConditionProfile noStringLevels = ConditionProfile.createBinaryProfile(); - private final ConditionProfile namesProfile = ConditionProfile.createBinaryProfile(); - @Child private VectorReadAccess.Int factorAccess = VectorReadAccess.Int.create(); + @Child private GetSplitNames getSplitNames = GetSplitNamesNodeGen.create(); private static final int INITIAL_SIZE = 5; private static final int SCALE_FACTOR = 2; @@ -73,238 +68,234 @@ public abstract class Split extends RBuiltinNode.Arg2 { Casts.noCasts(Split.class); } - @Specialization - protected RList split(RAbstractListVector x, RAbstractIntVector f) { - Object fStore = factorAccess.getDataStore(f); - RStringVector names = getLevelNode.execute(f); - final int nLevels = getNLevels(names); - - // initialise result arrays - Object[][] collectResults = new Object[nLevels][]; - int[] collectResultSize = new int[nLevels]; - for (int i = 0; i < collectResults.length; i++) { - collectResults[i] = new Object[INITIAL_SIZE]; - } - - // perform split - int factorLen = f.getLength(); - for (int i = 0, fi = 0; i < x.getLength(); ++i, fi = Utils.incMod(fi, factorLen)) { - int resultIndex = factorAccess.getDataAt(f, fStore, fi) - 1; // a factor is a 1-based - // int vector - Object[] collect = collectResults[resultIndex]; - if (collect.length == collectResultSize[resultIndex]) { - collectResults[resultIndex] = Arrays.copyOf(collect, collect.length * SCALE_FACTOR); - collect = collectResults[resultIndex]; + @Specialization(limit = "4", guards = {"xAccess.supports(x)", "fAccess.supports(f)"}) + protected RList split(RAbstractVector x, RAbstractIntVector f, + @Cached("x.access()") VectorAccess xAccess, + @Cached("f.access()") VectorAccess fAccess) { + try (SequentialIterator xIter = xAccess.access(x); SequentialIterator fIter = fAccess.access(f)) { + RStringVector names = getLevelNode.execute(f); + int nLevels = getNLevels(names); + int[] collectResultSize = new int[nLevels]; + Object[] results = new Object[nLevels]; + + switch (xAccess.getType()) { + case Character: { + // Initialize result arrays + String[][] collectResults = new String[nLevels][INITIAL_SIZE]; + + // perform split + while (xAccess.next(xIter)) { + fAccess.nextWithWrap(fIter); + int resultIndex = fAccess.getInt(fIter) - 1; // a factor is a 1-based int + // vector + String[] collect = collectResults[resultIndex]; + if (collect.length == collectResultSize[resultIndex]) { + collectResults[resultIndex] = collect = Arrays.copyOf(collect, collect.length * SCALE_FACTOR); + } + collect[collectResultSize[resultIndex]++] = xAccess.getString(xIter); + } + + RStringVector[] resultNames = getSplitNames.getNames(x, fAccess, fIter, nLevels, collectResultSize); + for (int i = 0; i < nLevels; i++) { + results[i] = RDataFactory.createStringVector(Arrays.copyOfRange(collectResults[i], 0, collectResultSize[i]), x.isComplete(), + (resultNames != null) ? resultNames[i] : null); + } + break; + } + case Complex: { + // Initialize result arrays + double[][] collectResults = new double[nLevels][INITIAL_SIZE * 2]; + + // perform split + while (xAccess.next(xIter)) { + fAccess.nextWithWrap(fIter); + int resultIndex = fAccess.getInt(fIter) - 1; // a factor is a 1-based int + // vector + double[] collect = collectResults[resultIndex]; + if (collect.length == collectResultSize[resultIndex] * 2) { + collectResults[resultIndex] = collect = Arrays.copyOf(collect, collect.length * SCALE_FACTOR); + } + collect[collectResultSize[resultIndex] * 2] = xAccess.getComplexR(xIter); + collect[collectResultSize[resultIndex] * 2 + 1] = xAccess.getComplexI(xIter); + collectResultSize[resultIndex]++; + } + + RStringVector[] resultNames = getSplitNames.getNames(x, fAccess, fIter, nLevels, collectResultSize); + for (int i = 0; i < nLevels; i++) { + results[i] = RDataFactory.createComplexVector(Arrays.copyOfRange(collectResults[i], 0, collectResultSize[i] * 2), x.isComplete(), + (resultNames != null) ? resultNames[i] : null); + } + break; + } + case Double: { + // Initialize result arrays + double[][] collectResults = new double[nLevels][INITIAL_SIZE]; + + // perform split + while (xAccess.next(xIter)) { + fAccess.nextWithWrap(fIter); + int resultIndex = fAccess.getInt(fIter) - 1; // a factor is a 1-based int + // vector + double[] collect = collectResults[resultIndex]; + if (collect.length == collectResultSize[resultIndex]) { + collectResults[resultIndex] = collect = Arrays.copyOf(collect, collect.length * SCALE_FACTOR); + } + collect[collectResultSize[resultIndex]++] = xAccess.getDouble(xIter); + } + + RStringVector[] resultNames = getSplitNames.getNames(x, fAccess, fIter, nLevels, collectResultSize); + for (int i = 0; i < nLevels; i++) { + results[i] = RDataFactory.createDoubleVector(Arrays.copyOfRange(collectResults[i], 0, collectResultSize[i]), x.isComplete(), + (resultNames != null) ? resultNames[i] : null); + } + break; + } + case Integer: { + // Initialize result arrays + int[][] collectResults = new int[nLevels][INITIAL_SIZE]; + + // perform split + while (xAccess.next(xIter)) { + fAccess.nextWithWrap(fIter); + int resultIndex = fAccess.getInt(fIter) - 1; // a factor is a 1-based int + // vector + int[] collect = collectResults[resultIndex]; + if (collect.length == collectResultSize[resultIndex]) { + collectResults[resultIndex] = collect = Arrays.copyOf(collect, collect.length * SCALE_FACTOR); + } + collect[collectResultSize[resultIndex]++] = xAccess.getInt(xIter); + } + + RStringVector[] resultNames = getSplitNames.getNames(x, fAccess, fIter, nLevels, collectResultSize); + for (int i = 0; i < nLevels; i++) { + results[i] = RDataFactory.createIntVector(Arrays.copyOfRange(collectResults[i], 0, collectResultSize[i]), x.isComplete(), + (resultNames != null) ? resultNames[i] : null); + } + break; + } + case List: { + // Initialize result arrays + Object[][] collectResults = new Object[nLevels][INITIAL_SIZE]; + + // perform split + while (xAccess.next(xIter)) { + fAccess.nextWithWrap(fIter); + int resultIndex = fAccess.getInt(fIter) - 1; // a factor is a 1-based int + // vector + Object[] collect = collectResults[resultIndex]; + if (collect.length == collectResultSize[resultIndex]) { + collectResults[resultIndex] = collect = Arrays.copyOf(collect, collect.length * SCALE_FACTOR); + } + collect[collectResultSize[resultIndex]++] = xAccess.getListElement(xIter); + } + + RStringVector[] resultNames = getSplitNames.getNames(x, fAccess, fIter, nLevels, collectResultSize); + for (int i = 0; i < nLevels; i++) { + results[i] = RDataFactory.createList(Arrays.copyOfRange(collectResults[i], 0, collectResultSize[i]), + (resultNames != null) ? resultNames[i] : null); + } + break; + } + case Logical: { + // Initialize result arrays + byte[][] collectResults = new byte[nLevels][INITIAL_SIZE]; + + // perform split + while (xAccess.next(xIter)) { + fAccess.nextWithWrap(fIter); + int resultIndex = fAccess.getInt(fIter) - 1; // a factor is a 1-based int + // vector + byte[] collect = collectResults[resultIndex]; + if (collect.length == collectResultSize[resultIndex]) { + collectResults[resultIndex] = collect = Arrays.copyOf(collect, collect.length * SCALE_FACTOR); + } + collect[collectResultSize[resultIndex]++] = xAccess.getLogical(xIter); + } + + RStringVector[] resultNames = getSplitNames.getNames(x, fAccess, fIter, nLevels, collectResultSize); + for (int i = 0; i < nLevels; i++) { + results[i] = RDataFactory.createLogicalVector(Arrays.copyOfRange(collectResults[i], 0, collectResultSize[i]), x.isComplete(), + (resultNames != null) ? resultNames[i] : null); + } + break; + } + case Raw: { + // Initialize result arrays + byte[][] collectResults = new byte[nLevels][INITIAL_SIZE]; + + // perform split + while (xAccess.next(xIter)) { + fAccess.nextWithWrap(fIter); + int resultIndex = fAccess.getInt(fIter) - 1; // a factor is a 1-based int + // vector + byte[] collect = collectResults[resultIndex]; + if (collect.length == collectResultSize[resultIndex]) { + collectResults[resultIndex] = collect = Arrays.copyOf(collect, collect.length * SCALE_FACTOR); + } + collect[collectResultSize[resultIndex]++] = xAccess.getRaw(xIter); + } + + RStringVector[] resultNames = getSplitNames.getNames(x, fAccess, fIter, nLevels, collectResultSize); + for (int i = 0; i < nLevels; i++) { + results[i] = RDataFactory.createRawVector(Arrays.copyOfRange(collectResults[i], 0, collectResultSize[i]), + (resultNames != null) ? resultNames[i] : null); + } + break; + } + default: + throw error(Message.UNIMPLEMENTED_TYPE_IN_FUNCTION, xAccess.getType().getName(), "split"); } - collect[collectResultSize[resultIndex]++] = x.getDataAt(i); - } - - Object[] results = new Object[nLevels]; - RStringVector[] resultNames = getNames(x, f, fStore, nLevels, collectResultSize); - for (int i = 0; i < nLevels; i++) { - results[i] = RDataFactory.createList(Arrays.copyOfRange(collectResults[i], 0, collectResultSize[i]), - (resultNames != null) ? resultNames[i] : null); - } - return RDataFactory.createList(results, names); - } - - @Specialization - protected RList split(RAbstractIntVector x, RAbstractIntVector f) { - Object fStore = factorAccess.getDataStore(f); - RStringVector names = getLevelNode.execute(f); - final int nLevels = getNLevels(names); - - // initialise result arrays - int[][] collectResults = new int[nLevels][]; - int[] collectResultSize = new int[nLevels]; - for (int i = 0; i < collectResults.length; i++) { - collectResults[i] = new int[INITIAL_SIZE]; - } - - // perform split - int factorLen = f.getLength(); - for (int i = 0, fi = 0; i < x.getLength(); ++i, fi = Utils.incMod(fi, factorLen)) { - int resultIndex = factorAccess.getDataAt(f, fStore, fi) - 1; // a factor is a 1-based - // int vector - int[] collect = collectResults[resultIndex]; - if (collect.length == collectResultSize[resultIndex]) { - collectResults[resultIndex] = Arrays.copyOf(collect, collect.length * SCALE_FACTOR); - collect = collectResults[resultIndex]; - } - collect[collectResultSize[resultIndex]++] = x.getDataAt(i); - } - - Object[] results = new Object[nLevels]; - RStringVector[] resultNames = getNames(x, f, fStore, nLevels, collectResultSize); - for (int i = 0; i < nLevels; i++) { - results[i] = RDataFactory.createIntVector(Arrays.copyOfRange(collectResults[i], 0, collectResultSize[i]), x.isComplete(), - (resultNames != null) ? resultNames[i] : null); - } - return RDataFactory.createList(results, names); - } - - @Specialization - protected RList split(RAbstractDoubleVector x, RAbstractIntVector f) { - Object fStore = factorAccess.getDataStore(f); - RStringVector names = getLevelNode.execute(f); - final int nLevels = getNLevels(names); - - // initialise result arrays - double[][] collectResults = new double[nLevels][]; - int[] collectResultSize = new int[nLevels]; - for (int i = 0; i < collectResults.length; i++) { - collectResults[i] = new double[INITIAL_SIZE]; - } - - // perform split - int factorLen = f.getLength(); - for (int i = 0, fi = 0; i < x.getLength(); ++i, fi = Utils.incMod(fi, factorLen)) { - int resultIndex = factorAccess.getDataAt(f, fStore, fi) - 1; // a factor is a 1-based - // int vector - double[] collect = collectResults[resultIndex]; - if (collect.length == collectResultSize[resultIndex]) { - collectResults[resultIndex] = Arrays.copyOf(collect, collect.length * SCALE_FACTOR); - collect = collectResults[resultIndex]; - } - collect[collectResultSize[resultIndex]++] = x.getDataAt(i); - } - - Object[] results = new Object[nLevels]; - RStringVector[] resultNames = getNames(x, f, fStore, nLevels, collectResultSize); - for (int i = 0; i < nLevels; i++) { - results[i] = RDataFactory.createDoubleVector(Arrays.copyOfRange(collectResults[i], 0, collectResultSize[i]), RDataFactory.COMPLETE_VECTOR, - (resultNames != null) ? resultNames[i] : null); + return RDataFactory.createList(results, names); } - return RDataFactory.createList(results, names); } - @Specialization - protected RList split(RAbstractStringVector x, RAbstractIntVector f) { - Object fStore = factorAccess.getDataStore(f); - RStringVector names = getLevelNode.execute(f); - final int nLevels = getNLevels(names); - - // initialise result arrays - String[][] collectResults = new String[nLevels][]; - int[] collectResultSize = new int[nLevels]; - for (int i = 0; i < collectResults.length; i++) { - collectResults[i] = new String[INITIAL_SIZE]; - } - - // perform split - int factorLen = f.getLength(); - for (int i = 0, fi = 0; i < x.getLength(); ++i, fi = Utils.incMod(fi, factorLen)) { - int resultIndex = factorAccess.getDataAt(f, fStore, fi) - 1; // a factor is a 1-based - // int vector - String[] collect = collectResults[resultIndex]; - if (collect.length == collectResultSize[resultIndex]) { - collectResults[resultIndex] = Arrays.copyOf(collect, collect.length * SCALE_FACTOR); - collect = collectResults[resultIndex]; - } - collect[collectResultSize[resultIndex]++] = x.getDataAt(i); - } - - Object[] results = new Object[nLevels]; - RStringVector[] resultNames = getNames(x, f, fStore, nLevels, collectResultSize); - for (int i = 0; i < nLevels; i++) { - results[i] = RDataFactory.createStringVector(Arrays.copyOfRange(collectResults[i], 0, collectResultSize[i]), RDataFactory.COMPLETE_VECTOR, - (resultNames != null) ? resultNames[i] : null); - } - return RDataFactory.createList(results, names); + @Specialization(replaces = "split") + protected RList splitGeneric(RAbstractVector x, RAbstractIntVector f) { + return split(x, f, x.slowPathAccess(), f.slowPathAccess()); } - @Specialization - protected RList split(RAbstractLogicalVector x, RAbstractIntVector f) { - Object fStore = factorAccess.getDataStore(f); - RStringVector names = getLevelNode.execute(f); - final int nLevels = getNLevels(names); - - // initialise result arrays - byte[][] collectResults = new byte[nLevels][]; - int[] collectResultSize = new int[nLevels]; - for (int i = 0; i < collectResults.length; i++) { - collectResults[i] = new byte[INITIAL_SIZE]; - } - - // perform split - int factorLen = f.getLength(); - for (int i = 0, fi = 0; i < x.getLength(); ++i, fi = Utils.incMod(fi, factorLen)) { - int resultIndex = factorAccess.getDataAt(f, fStore, fi) - 1; // a factor is a 1-based - // int vector - byte[] collect = collectResults[resultIndex]; - if (collect.length == collectResultSize[resultIndex]) { - collectResults[resultIndex] = Arrays.copyOf(collect, collect.length * SCALE_FACTOR); - collect = collectResults[resultIndex]; + protected abstract static class GetSplitNames extends RBaseNode { + + private final ConditionProfile namesProfile = ConditionProfile.createBinaryProfile(); + @Child private GetNamesAttributeNode getNamesNode = GetNamesAttributeNode.create(); + + private RStringVector[] getNames(RAbstractVector x, VectorAccess fAccess, SequentialIterator fIter, int nLevels, int[] collectResultSize) { + RStringVector xNames = getNamesNode.getNames(x); + if (namesProfile.profile(xNames != null)) { + String[][] namesArr = new String[nLevels][]; + int[] resultNamesIdxs = new int[nLevels]; + for (int i = 0; i < nLevels; i++) { + namesArr[i] = new String[collectResultSize[i]]; + } + execute(fAccess, fIter, xNames, namesArr, resultNamesIdxs); + RStringVector[] resultNames = new RStringVector[nLevels]; + for (int i = 0; i < nLevels; i++) { + resultNames[i] = RDataFactory.createStringVector(namesArr[i], xNames.isComplete()); + } + return resultNames; } - collect[collectResultSize[resultIndex]++] = x.getDataAt(i); + return null; } - Object[] results = new Object[nLevels]; - RStringVector[] resultNames = getNames(x, f, fStore, nLevels, collectResultSize); - for (int i = 0; i < nLevels; i++) { - results[i] = RDataFactory.createLogicalVector(Arrays.copyOfRange(collectResults[i], 0, collectResultSize[i]), x.isComplete(), - (resultNames != null) ? resultNames[i] : null); - } - return RDataFactory.createList(results, names); - } - - @Specialization - protected RList split(RRawVector x, RAbstractIntVector f) { - Object fStore = factorAccess.getDataStore(f); - RStringVector names = getLevelNode.execute(f); - final int nLevels = getNLevels(names); - - // initialise result arrays - byte[][] collectResults = new byte[nLevels][]; - int[] collectResultSize = new int[nLevels]; - for (int i = 0; i < collectResults.length; i++) { - collectResults[i] = new byte[INITIAL_SIZE]; - } - - // perform split - int factorLen = f.getLength(); - for (int i = 0, fi = 0; i < x.getLength(); ++i, fi = Utils.incMod(fi, factorLen)) { - int resultIndex = factorAccess.getDataAt(f, fStore, fi) - 1; // a factor is a 1-based - // int vector - byte[] collect = collectResults[resultIndex]; - if (collect.length == collectResultSize[resultIndex]) { - collectResults[resultIndex] = Arrays.copyOf(collect, collect.length * SCALE_FACTOR); - collect = collectResults[resultIndex]; + protected abstract void execute(VectorAccess fAccess, SequentialIterator fIter, RStringVector names, String[][] namesArr, int[] resultNamesIdxs); + + @Specialization(guards = "namesAccess.supports(names)") + protected void fillNames(VectorAccess fAccess, SequentialIterator fIter, RStringVector names, String[][] namesArr, int[] resultNamesIdxs, + @Cached("names.access()") VectorAccess namesAccess) { + try (SequentialIterator namesIter = namesAccess.access(names)) { + while (namesAccess.next(namesIter)) { + fAccess.nextWithWrap(fIter); + int resultIndex = fAccess.getInt(fIter) - 1; // a factor is a 1-based int + // vector + namesArr[resultIndex][resultNamesIdxs[resultIndex]++] = namesAccess.getString(namesIter); + } } - collect[collectResultSize[resultIndex]++] = x.getRawDataAt(i); } - Object[] results = new Object[nLevels]; - RStringVector[] resultNames = getNames(x, f, fStore, nLevels, collectResultSize); - for (int i = 0; i < nLevels; i++) { - results[i] = RDataFactory.createRawVector(Arrays.copyOfRange(collectResults[i], 0, collectResultSize[i]), - (resultNames != null) ? resultNames[i] : null); - } - return RDataFactory.createList(results, names); - } - - private RStringVector[] getNames(RAbstractVector x, RAbstractIntVector factor, Object fStore, int nLevels, int[] collectResultSize) { - RStringVector xNames = getNamesNode.getNames(x); - if (namesProfile.profile(xNames != null)) { - String[][] namesArr = new String[nLevels][]; - int[] resultNamesIdxs = new int[nLevels]; - for (int i = 0; i < nLevels; i++) { - namesArr[i] = new String[collectResultSize[i]]; - } - int factorLen = factor.getLength(); - for (int i = 0, fi = 0; i < x.getLength(); ++i, fi = Utils.incMod(fi, factorLen)) { - int resultIndex = factorAccess.getDataAt(factor, fStore, fi) - 1; // a factor is a - // 1-based int - // vector - namesArr[resultIndex][resultNamesIdxs[resultIndex]++] = xNames.getDataAt(i); - } - RStringVector[] resultNames = new RStringVector[nLevels]; - for (int i = 0; i < nLevels; i++) { - resultNames[i] = RDataFactory.createStringVector(namesArr[i], xNames.isComplete()); - } - return resultNames; + @Specialization(replaces = "fillNames") + protected void fillNamesGeneric(VectorAccess fAccess, SequentialIterator fIter, RStringVector names, String[][] namesArr, int[] resultNamesIdxs) { + fillNames(fAccess, fIter, names, namesArr, resultNamesIdxs, names.slowPathAccess()); } - return null; } private static int getNLevels(RStringVector levels) { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java index b61abc3dc009fa760ff99ca77b3bcb47ffc58f11..e48d76d60b510b783cc077266c62cd66f2114f8d 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Transpose.java @@ -15,10 +15,6 @@ package com.oracle.truffle.r.nodes.builtin.base; import static com.oracle.truffle.r.runtime.builtins.RBehavior.PURE; import static com.oracle.truffle.r.runtime.builtins.RBuiltinKind.INTERNAL; -import java.util.function.BiFunction; -import java.util.function.Function; - -import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.Specialization; @@ -35,27 +31,18 @@ import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.function.opt.ReuseNonSharedNode; import com.oracle.truffle.r.nodes.profile.VectorLengthProfile; import com.oracle.truffle.r.runtime.RError.Message; +import com.oracle.truffle.r.runtime.RInternalError; +import com.oracle.truffle.r.runtime.RType; import com.oracle.truffle.r.runtime.builtins.RBuiltin; -import com.oracle.truffle.r.runtime.data.RComplex; -import com.oracle.truffle.r.runtime.data.RComplexVector; import com.oracle.truffle.r.runtime.data.RDataFactory; -import com.oracle.truffle.r.runtime.data.RDoubleVector; -import com.oracle.truffle.r.runtime.data.RIntVector; +import com.oracle.truffle.r.runtime.data.RDataFactory.VectorFactory; import com.oracle.truffle.r.runtime.data.RList; -import com.oracle.truffle.r.runtime.data.RLogicalVector; -import com.oracle.truffle.r.runtime.data.RRawVector; import com.oracle.truffle.r.runtime.data.RStringVector; import com.oracle.truffle.r.runtime.data.RVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractComplexVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractDoubleVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractIntVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractListVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractLogicalVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractRawVector; -import com.oracle.truffle.r.runtime.data.model.RAbstractStringVector; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; -import com.oracle.truffle.r.runtime.data.nodes.SetDataAt; -import com.oracle.truffle.r.runtime.data.nodes.VectorReadAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess; +import com.oracle.truffle.r.runtime.data.nodes.VectorAccess.RandomIterator; +import com.oracle.truffle.r.runtime.data.nodes.VectorReuse; import com.oracle.truffle.r.runtime.nodes.RBaseNode; @RBuiltin(name = "t.default", kind = INTERNAL, parameterNames = {"x"}, behavior = PURE) @@ -72,7 +59,7 @@ public abstract class Transpose extends RBuiltinNode.Arg1 { @Child private SetFixedAttributeNode putDimNames = SetFixedAttributeNode.createDimNames(); @Child private GetDimNamesAttributeNode getDimNamesNode = GetDimNamesAttributeNode.create(); @Child private GetNamesAttributeNode getAxisNamesNode = GetNamesAttributeNode.create(); - @Child private GetDimAttributeNode getDimNode; + @Child private GetDimAttributeNode getDimNode = GetDimAttributeNode.create(); @Child private ReuseNonSharedNode reuseNonShared = ReuseNonSharedNode.create(); static { @@ -81,210 +68,157 @@ public abstract class Transpose extends RBuiltinNode.Arg1 { public abstract Object execute(RAbstractVector o); - @FunctionalInterface - private interface WriteArray<T extends RAbstractVector, A> { - void apply(A array, T vector, int i, int j); - } - - @FunctionalInterface - private interface Swap { - /** Swap element at (i, j) with element at (j, i). */ - void swap(int i, int j); + protected boolean isSquare(RAbstractVector vector) { + int[] dims = getDimNode.getDimensions(vector); + if (GetDimAttributeNode.isMatrix(dims)) { + assert dims.length >= 2; + return dims[0] == dims[1]; + } + return false; } - protected <T extends RAbstractVector, A> RVector<?> transposeInternal(T vector, Function<Integer, A> createArray, WriteArray<T, A> writeArray, BiFunction<A, Boolean, RVector<?>> createResult) { - int length = lengthProfile.profile(vector.getLength()); - int firstDim; - int secondDim; - assert vector.isMatrix(); - int[] dims = getDimensions(vector); - firstDim = dims[0]; - secondDim = dims[1]; - RBaseNode.reportWork(this, length); - - A array = createArray.apply(length); - int j = 0; - loopProfile.profileCounted(length); - for (int i = 0; loopProfile.inject(i < length); i++, j += firstDim) { - if (j > (length - 1)) { - j -= (length - 1); - } - writeArray.apply(array, vector, i, j); - } - RVector<?> r = createResult.apply(array, vector.isComplete()); - // copy attributes - copyRegAttributes.execute(vector, r); - // set new dimensions - int[] newDim = new int[]{secondDim, firstDim}; - putNewDimensions(vector, r, newDim); - return r; + protected boolean isMatrix(RAbstractVector vector) { + return GetDimAttributeNode.isMatrix(getDimNode.getDimensions(vector)); } - protected RVector<?> transposeSquareMatrixInPlace(RVector<?> vector, Object store, VectorReadAccess readAccess, SetDataAt setter, Swap swap) { + private void transposeSquareMatrixInPlace(RAbstractVector vector, RandomIterator iter, VectorAccess access) { int length = lengthProfile.profile(vector.getLength()); - assert vector.isMatrix(); - int[] dims = getDimensions(vector); + assert isMatrix(vector); + int[] dims = getDimNode.getDimensions(vector); assert dims.length == 2; assert dims[0] == dims[1]; int dim = dims[0]; RBaseNode.reportWork(this, length); + RType type = access.getType(); loopProfile.profileCounted(length); for (int i = 0; loopProfile.inject(i < dim); i++) { for (int j = 0; j < i; j++) { int swapi = i * dim + j; int swapj = j * dim + i; - if (swap != null) { - swap.swap(swapi, swapj); - } else { - Object tmp = readAccess.getDataAtAsObject(vector, store, swapi); - Object jVal = readAccess.getDataAtAsObject(vector, store, swapj); - setter.setDataAtAsObject(vector, store, swapi, jVal); - setter.setDataAtAsObject(vector, store, swapj, tmp); + switch (type) { + case Character: { + String tmp = access.getString(iter, swapi); + access.setString(iter, swapi, access.getString(iter, swapj)); + access.setString(iter, swapj, tmp); + break; + } + case Complex: { + double tmpReal = access.getComplexR(iter, swapi); + double tmpImaginary = access.getComplexI(iter, swapi); + access.setComplex(iter, swapi, access.getComplexR(iter, swapj), access.getComplexI(iter, swapj)); + access.setComplex(iter, swapj, tmpReal, tmpImaginary); + break; + } + case Double: { + double tmp = access.getDouble(iter, swapi); + access.setDouble(iter, swapi, access.getDouble(iter, swapj)); + access.setDouble(iter, swapj, tmp); + break; + } + case Integer: { + int tmp = access.getInt(iter, swapi); + access.setInt(iter, swapi, access.getInt(iter, swapj)); + access.setInt(iter, swapj, tmp); + break; + } + case List: { + Object tmp = access.getListElement(iter, swapi); + access.setListElement(iter, swapi, access.getListElement(iter, swapj)); + access.setListElement(iter, swapj, tmp); + break; + } + case Logical: { + byte tmp = access.getLogical(iter, swapi); + access.setLogical(iter, swapi, access.getLogical(iter, swapj)); + access.setLogical(iter, swapj, tmp); + break; + } + case Raw: { + byte tmp = access.getRaw(iter, swapi); + access.setRaw(iter, swapi, access.getRaw(iter, swapj)); + access.setRaw(iter, swapj, tmp); + break; + } + default: + throw RInternalError.shouldNotReachHere(); } } } // don't need to set new dimensions; it is a square matrix putNewDimNames(vector, vector); - return vector; } - private int[] getDimensions(RAbstractVector vector) { - assert vector.isMatrix(); - if (getDimNode == null) { - CompilerDirectives.transferToInterpreterAndInvalidate(); - getDimNode = insert(GetDimAttributeNode.create()); + @Specialization(guards = {"isSquare(x)", "!isRExpression(x)", "xReuse.supports(x)"}) + protected RAbstractVector transposeSquare(RAbstractVector x, + @Cached("createNonShared(x)") VectorReuse xReuse) { + RAbstractVector result = xReuse.getResult(x); + VectorAccess resultAccess = xReuse.access(result); + try (RandomIterator resultIter = resultAccess.randomAccess(result)) { + transposeSquareMatrixInPlace(result, resultIter, resultAccess); } - return getDimNode.getDimensions(vector); - } - - protected boolean isSquare(RAbstractVector vector) { - if (vector.isMatrix()) { - int[] dims = getDimensions(vector); - assert dims.length >= 2; - return dims[0] == dims[1]; + return result; + } + + @Specialization(replaces = "transposeSquare", guards = {"isSquare(x)", "!isRExpression(x)"}) + protected RAbstractVector transposeSquareGeneric(RAbstractVector x, + @Cached("createNonSharedGeneric()") VectorReuse xReuse) { + return transposeSquare(x, xReuse); + } + + @Specialization(guards = {"isMatrix(x)", "!isSquare(x)", "!isRExpression(x)", "xAccess.supports(x)"}) + protected RAbstractVector transpose(RAbstractVector x, + @Cached("create()") VectorFactory factory, + @Cached("x.access()") VectorAccess xAccess, + @Cached("createNew(xAccess.getType())") VectorAccess resultAccess) { + try (RandomIterator xIter = xAccess.randomAccess(x)) { + RAbstractVector result = factory.createVector(xAccess.getType(), xAccess.getLength(xIter), false); + try (RandomIterator resultIter = resultAccess.randomAccess(result)) { + int length = lengthProfile.profile(xAccess.getLength(xIter)); + assert isMatrix(x); + int[] dims = getDimNode.getDimensions(x); + int firstDim = dims[0]; + int secondDim = dims[1]; + RBaseNode.reportWork(this, length); + + int j = 0; + loopProfile.profileCounted(length); + for (int i = 0; loopProfile.inject(i < length); i++, j += firstDim) { + if (j > (length - 1)) { + j -= (length - 1); + } + resultAccess.setFromSameType(resultIter, i, xAccess, xIter, j); + } + // copy attributes + copyRegAttributes.execute(x, result); + // set new dimensions + putNewDimensions(x, result, new int[]{secondDim, firstDim}); + } + result.setComplete(x.isComplete()); + return result; } - return false; - } - - @Specialization(guards = "isSquare(x)") - protected RVector<?> transposeSquare(RAbstractIntVector x, - @Cached("create()") VectorReadAccess.Int readAccess, - @Cached("create()") SetDataAt.Int setter) { - RIntVector reused = (RIntVector) reuseNonShared.execute(x).materialize(); - Object reusedStore = readAccess.getDataStore(reused); - return transposeSquareMatrixInPlace(reused, reusedStore, readAccess, setter, null); - } - - @Specialization(guards = "isSquare(x)") - protected RVector<?> transposeSquare(RAbstractLogicalVector x, - @Cached("create()") VectorReadAccess.Logical readAccess, - @Cached("create()") SetDataAt.Logical setter) { - RLogicalVector reused = (RLogicalVector) reuseNonShared.execute(x).materialize(); - Object reusedStore = readAccess.getDataStore(reused); - return transposeSquareMatrixInPlace(reused, reusedStore, readAccess, setter, null); - } - - @Specialization(guards = "isSquare(x)") - protected RVector<?> transposeSquare(RAbstractDoubleVector x, - @Cached("create()") VectorReadAccess.Double readAccess, - @Cached("create()") SetDataAt.Double setter) { - RDoubleVector reused = (RDoubleVector) reuseNonShared.execute(x).materialize(); - Object reusedStore = readAccess.getDataStore(reused); - return transposeSquareMatrixInPlace(reused, reusedStore, readAccess, setter, null); - } - - @Specialization(guards = "isSquare(x)") - protected RVector<?> transposeSquare(RAbstractComplexVector x, - @Cached("create()") VectorReadAccess.Complex readAccess, - @Cached("create()") SetDataAt.Complex setter) { - RComplexVector reused = (RComplexVector) reuseNonShared.execute(x).materialize(); - Object reusedStore = readAccess.getDataStore(reused); - return transposeSquareMatrixInPlace(reused, reusedStore, readAccess, setter, null); - } - - @Specialization(guards = "isSquare(x)") - protected RVector<?> transposeSquare(RAbstractStringVector x, - @Cached("create()") VectorReadAccess.String readAccess, - @Cached("create()") SetDataAt.String setter) { - RStringVector reused = (RStringVector) reuseNonShared.execute(x).materialize(); - Object reusedStore = readAccess.getDataStore(reused); - return transposeSquareMatrixInPlace(reused, reusedStore, readAccess, setter, null); - } - - @Specialization(guards = "isSquare(x)") - protected RVector<?> transposeSquare(RAbstractListVector x) { - RList reused = (RList) reuseNonShared.execute(x).materialize(); - Object[] store = reused.getDataWithoutCopying(); - return transposeSquareMatrixInPlace(reused, store, null, null, (i, j) -> { - Object tmp = store[i]; - store[i] = store[j]; - store[j] = tmp; - }); - } - - @Specialization(guards = "isSquare(x)") - protected RVector<?> transposeSquare(RAbstractRawVector x, - @Cached("create()") VectorReadAccess.Raw readAccess, - @Cached("create()") SetDataAt.Raw setter) { - RRawVector reused = (RRawVector) reuseNonShared.execute(x).materialize(); - Object reusedStore = readAccess.getDataStore(reused); - return transposeSquareMatrixInPlace(reused, reusedStore, readAccess, setter, null); - } - - @Specialization(guards = {"x.isMatrix()", "!isSquare(x)"}) - protected RVector<?> transpose(RAbstractIntVector x) { - return transposeInternal(x, l -> new int[l], (a, v, i, j) -> a[i] = v.getDataAt(j), RDataFactory::createIntVector); - } - - @Specialization(guards = {"x.isMatrix()", "!isSquare(x)"}) - protected RVector<?> transpose(RAbstractLogicalVector x) { - return transposeInternal(x, l -> new byte[l], (a, v, i, j) -> a[i] = v.getDataAt(j), RDataFactory::createLogicalVector); - } - - @Specialization(guards = {"x.isMatrix()", "!isSquare(x)"}) - protected RVector<?> transpose(RAbstractDoubleVector x) { - return transposeInternal(x, l -> new double[l], (a, v, i, j) -> a[i] = v.getDataAt(j), RDataFactory::createDoubleVector); - } - - @Specialization(guards = {"x.isMatrix()", "!isSquare(x)"}) - protected RVector<?> transpose(RAbstractComplexVector x) { - return transposeInternal(x, l -> new double[l * 2], (a, v, i, j) -> { - RComplex d = v.getDataAt(j); - a[i * 2] = d.getRealPart(); - a[i * 2 + 1] = d.getImaginaryPart(); - }, RDataFactory::createComplexVector); - } - - @Specialization(guards = {"x.isMatrix()", "!isSquare(x)"}) - protected RVector<?> transpose(RAbstractStringVector x) { - return transposeInternal(x, l -> new String[l], (a, v, i, j) -> a[i] = v.getDataAt(j), RDataFactory::createStringVector); - } - - @Specialization(guards = {"x.isMatrix()", "!isSquare(x)"}) - protected RVector<?> transpose(RAbstractListVector x) { - return transposeInternal(x, l -> new Object[l], (a, v, i, j) -> a[i] = v.getDataAt(j), (a, c) -> RDataFactory.createList(a)); } - @Specialization(guards = {"x.isMatrix()", "!isSquare(x)"}) - protected RVector<?> transpose(RAbstractRawVector x) { - return transposeInternal(x, l -> new byte[l], (a, v, i, j) -> a[i] = v.getRawDataAt(j), (a, c) -> RDataFactory.createRawVector(a)); + @Specialization(replaces = "transpose", guards = {"isMatrix(x)", "!isSquare(x)", "!isRExpression(x)"}) + protected RAbstractVector transposeGeneric(RAbstractVector x, + @Cached("create()") VectorFactory factory) { + return transpose(x, factory, x.slowPathAccess(), VectorAccess.createSlowPathNew(x.getRType())); } - @Specialization(guards = "!x.isMatrix()") - protected RVector<?> transpose(RAbstractVector x) { + @Specialization(guards = {"!isMatrix(x)", "!isRExpression(x)"}) + protected RVector<?> transposeNonMatrix(RAbstractVector x) { RVector<?> reused = reuseNonShared.execute(x); putNewDimensions(reused, reused, new int[]{1, x.getLength()}); return reused; } - private void putNewDimensions(RAbstractVector source, RVector<?> dest, int[] newDim) { + private void putNewDimensions(RAbstractVector source, RAbstractVector dest, int[] newDim) { putDimensions.execute(initAttributes.execute(dest), RDataFactory.createIntVector(newDim, RDataFactory.COMPLETE_VECTOR)); putNewDimNames(source, dest); } - private void putNewDimNames(RAbstractVector source, RVector<?> dest) { + private void putNewDimNames(RAbstractVector source, RAbstractVector dest) { // set new dim names RList dimNames = getDimNamesNode.getDimNames(source); if (dimNames != null) { @@ -292,8 +226,7 @@ public abstract class Transpose extends RBuiltinNode.Arg1 { assert dimNames.getLength() == 2; RStringVector axisNames = getAxisNamesNode.getNames(dimNames); RStringVector transAxisNames = axisNames == null ? null : RDataFactory.createStringVector(new String[]{axisNames.getDataAt(1), axisNames.getDataAt(0)}, true); - RList newDimNames = RDataFactory.createList(new Object[]{dimNames.getDataAt(1), - dimNames.getDataAt(0)}, transAxisNames); + RList newDimNames = RDataFactory.createList(new Object[]{dimNames.getDataAt(1), dimNames.getDataAt(0)}, transAxisNames); putDimNames.execute(dest.getAttributes(), newDimNames); } }