diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Bind.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Bind.java index 853a071cb5cb0218d803cb97f519d208c62ec607..ce9e8aef90d9a161d4cce2e8638082cb241b0c2d 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Bind.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Bind.java @@ -38,6 +38,7 @@ import com.oracle.truffle.api.profiles.ValueProfile; import com.oracle.truffle.r.nodes.RASTUtils; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimAttributeNode; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.function.S3FunctionLookupNode; @@ -172,7 +173,8 @@ public abstract class Bind extends RBaseNode { } } - private Object bindInternal(int deparseLevel, Object[] args, RArgsValuesAndNames promiseArgs, CastNode castNode, boolean needsVectorCast, SetDimAttributeNode setDimNode) { + private Object bindInternal(int deparseLevel, Object[] args, RArgsValuesAndNames promiseArgs, CastNode castNode, boolean needsVectorCast, SetDimAttributeNode setDimNode, + SetDimNamesAttributeNode setDimNamesNode) { ArgumentsSignature signature = promiseArgs.getSignature(); String[] vecNames = nullNamesProfile.profile(signature.getNonNullCount() == 0) ? null : new String[signature.getLength()]; RAbstractVector[] vectors = new RAbstractVector[args.length]; @@ -228,52 +230,58 @@ public abstract class Bind extends RBaseNode { } } if (type == BindType.cbind) { - return genericCBind(promiseArgs, vectors, complete, vecNames, naCheck.neverSeenNA(), deparseLevel, setDimNode); + return genericCBind(promiseArgs, vectors, complete, vecNames, naCheck.neverSeenNA(), deparseLevel, setDimNode, setDimNamesNode); } else { - return genericRBind(promiseArgs, vectors, complete, vecNames, naCheck.neverSeenNA(), deparseLevel, setDimNode); + return genericRBind(promiseArgs, vectors, complete, vecNames, naCheck.neverSeenNA(), deparseLevel, setDimNode, setDimNamesNode); } } @Specialization(guards = {"precedence == LOGICAL_PRECEDENCE", "args.length > 1", "!isDataFrame(args)"}) protected Object allLogical(int deparseLevel, Object[] args, RArgsValuesAndNames promiseArgs, @SuppressWarnings("unused") int precedence, // @Cached("create()") CastLogicalNode cast, - @Cached("create()") SetDimAttributeNode setDimNode) { - return bindInternal(deparseLevel, args, promiseArgs, cast, true, setDimNode); + @Cached("create()") SetDimAttributeNode setDimNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { + return bindInternal(deparseLevel, args, promiseArgs, cast, true, setDimNode, setDimNamesNode); } @Specialization(guards = {"precedence == INT_PRECEDENCE", "args.length > 1", "!isDataFrame(args)"}) protected Object allInt(int deparseLevel, Object[] args, RArgsValuesAndNames promiseArgs, @SuppressWarnings("unused") int precedence, // @Cached("create()") CastIntegerNode cast, - @Cached("create()") SetDimAttributeNode setDimNode) { - return bindInternal(deparseLevel, args, promiseArgs, cast, true, setDimNode); + @Cached("create()") SetDimAttributeNode setDimNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { + return bindInternal(deparseLevel, args, promiseArgs, cast, true, setDimNode, setDimNamesNode); } @Specialization(guards = {"precedence == DOUBLE_PRECEDENCE", "args.length > 1", "!isDataFrame(args)"}) protected Object allDouble(int deparseLevel, Object[] args, RArgsValuesAndNames promiseArgs, @SuppressWarnings("unused") int precedence, // @Cached("create()") CastDoubleNode cast, - @Cached("create()") SetDimAttributeNode setDimNode) { - return bindInternal(deparseLevel, args, promiseArgs, cast, true, setDimNode); + @Cached("create()") SetDimAttributeNode setDimNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { + return bindInternal(deparseLevel, args, promiseArgs, cast, true, setDimNode, setDimNamesNode); } @Specialization(guards = {"precedence == STRING_PRECEDENCE", "args.length> 1", "!isDataFrame(args)"}) protected Object allString(int deparseLevel, Object[] args, RArgsValuesAndNames promiseArgs, @SuppressWarnings("unused") int precedence, // @Cached("create()") CastStringNode cast, - @Cached("create()") SetDimAttributeNode setDimNode) { - return bindInternal(deparseLevel, args, promiseArgs, cast, true, setDimNode); + @Cached("create()") SetDimAttributeNode setDimNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { + return bindInternal(deparseLevel, args, promiseArgs, cast, true, setDimNode, setDimNamesNode); } @Specialization(guards = {"precedence == COMPLEX_PRECEDENCE", "args.length > 1", "!isDataFrame(args)"}) protected Object allComplex(int deparseLevel, Object[] args, RArgsValuesAndNames promiseArgs, @SuppressWarnings("unused") int precedence, // @Cached("create()") CastComplexNode cast, - @Cached("create()") SetDimAttributeNode setDimNode) { - return bindInternal(deparseLevel, args, promiseArgs, cast, true, setDimNode); + @Cached("create()") SetDimAttributeNode setDimNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { + return bindInternal(deparseLevel, args, promiseArgs, cast, true, setDimNode, setDimNamesNode); } @Specialization(guards = {"precedence == LIST_PRECEDENCE", "args.length > 1", "!isDataFrame(args)"}) protected Object allList(int deparseLevel, Object[] args, RArgsValuesAndNames promiseArgs, @SuppressWarnings("unused") int precedence, // @Cached("create()") CastListNode cast, - @Cached("create()") SetDimAttributeNode setDimNode) { - return bindInternal(deparseLevel, args, promiseArgs, cast, false, setDimNode); + @Cached("create()") SetDimAttributeNode setDimNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { + return bindInternal(deparseLevel, args, promiseArgs, cast, false, setDimNode, setDimNamesNode); } /** @@ -462,7 +470,8 @@ public abstract class Bind extends RBaseNode { private final BranchProfile everSeenNotEqualColumns = BranchProfile.create(); @Specialization(guards = {"precedence != NO_PRECEDENCE", "args.length == 1"}) - protected Object allOneElem(int deparseLevel, Object[] args, RArgsValuesAndNames promiseArgs, @SuppressWarnings("unused") int precedence) { + protected Object allOneElem(int deparseLevel, Object[] args, RArgsValuesAndNames promiseArgs, @SuppressWarnings("unused") int precedence, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { RAbstractVector vec = castVector(args[0]); if (vec.isMatrix()) { return vec; @@ -491,13 +500,13 @@ public abstract class Bind extends RBaseNode { } RVector<?> res = (RVector<?>) vec.copyWithNewDimensions(dims); - res.setDimNames(RDataFactory.createList(type == BindType.cbind ? new Object[]{dimNamesA, dimNamesB} : new Object[]{dimNamesB, dimNamesA})); + setDimNamesNode.execute(res, RDataFactory.createList(type == BindType.cbind ? new Object[]{dimNamesA, dimNamesB} : new Object[]{dimNamesB, dimNamesA})); res.copyRegAttributesFrom(vec); return res; } public RVector<?> genericCBind(RArgsValuesAndNames promiseArgs, RAbstractVector[] vectors, boolean complete, String[] vecNames, boolean vecNamesComplete, int deparseLevel, - SetDimAttributeNode setDimNode) { + SetDimAttributeNode setDimNode, SetDimNamesAttributeNode setDimNamesNode) { int[] resultDimensions = new int[2]; int[] secondDims = new int[vectors.length]; @@ -549,7 +558,7 @@ public abstract class Bind extends RBaseNode { } Object colDimResultNames = allColDimNamesNull ? RNull.instance : RDataFactory.createStringVector(colDimNamesArray, vecNamesComplete); setDimNode.setDimensions(result, resultDimensions); - result.setDimNames(RDataFactory.createList(new Object[]{rowDimResultNames, colDimResultNames})); + setDimNamesNode.setDimNames(result, RDataFactory.createList(new Object[]{rowDimResultNames, colDimResultNames})); return result; } @@ -579,7 +588,7 @@ public abstract class Bind extends RBaseNode { } public RVector<?> genericRBind(RArgsValuesAndNames promiseArgs, RAbstractVector[] vectors, boolean complete, String[] vecNames, boolean vecNamesComplete, int deparseLevel, - SetDimAttributeNode setDimNode) { + SetDimAttributeNode setDimNode, SetDimNamesAttributeNode setDimNamesNode) { int[] resultDimensions = new int[2]; int[] firstDims = new int[vectors.length]; @@ -635,7 +644,7 @@ public abstract class Bind extends RBaseNode { } Object rowDimResultNames = allRowDimNamesNull ? RNull.instance : RDataFactory.createStringVector(rowDimNamesArray, vecNamesComplete); setDimNode.setDimensions(result, resultDimensions); - result.setDimNames(RDataFactory.createList(new Object[]{rowDimResultNames, colDimResultNames})); + setDimNamesNode.setDimNames(result, RDataFactory.createList(new Object[]{rowDimResultNames, colDimResultNames})); return result; } } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Crossprod.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Crossprod.java index e59eb2809cc9c0ab9488e9fc92da42880af3a5b0..dec2884830fda1f105f25294440f528c9eadba3b 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Crossprod.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Crossprod.java @@ -30,6 +30,7 @@ import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.runtime.RError; @@ -66,14 +67,15 @@ public abstract class Crossprod extends RBuiltinNode { @Specialization(guards = {"x.isMatrix()", "y.isMatrix()"}) protected RDoubleVector crossprod(RAbstractDoubleVector x, RAbstractDoubleVector y, @Cached("create()") GetDimAttributeNode getXDimsNode, - @Cached("create()") GetDimAttributeNode getYDimsNode) { + @Cached("create()") GetDimAttributeNode getYDimsNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { int[] xDims = getXDimsNode.getDimensions(x); int[] yDims = getYDimsNode.getDimensions(y); int xRows = xDims[0]; int xCols = xDims[1]; int yRows = yDims[0]; int yCols = yDims[1]; - return matMult.doubleMatrixMultiply(x, y, xCols, xRows, yRows, yCols, xRows, 1, 1, yRows, false); + return matMult.doubleMatrixMultiply(x, y, xCols, xRows, yRows, yCols, xRows, 1, 1, yRows, false, setDimNamesNode); } private static RDoubleVector mirror(RDoubleVector result, GetDimAttributeNode getResultDimsNode) { @@ -104,11 +106,13 @@ public abstract class Crossprod extends RBuiltinNode { @Specialization(guards = "x.isMatrix()") protected RDoubleVector crossprodDoubleMatrix(RAbstractDoubleVector x, @SuppressWarnings("unused") RNull y, - @Cached("create()") GetDimAttributeNode getDimsNode, @Cached("create()") GetDimAttributeNode getResultDimsNode) { + @Cached("create()") GetDimAttributeNode getDimsNode, + @Cached("create()") GetDimAttributeNode getResultDimsNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { int[] xDims = getDimsNode.getDimensions(x); int xRows = xDims[0]; int xCols = xDims[1]; - return mirror(matMult.doubleMatrixMultiply(x, x, xCols, xRows, xRows, xCols, xRows, 1, 1, xRows, true), getResultDimsNode); + return mirror(matMult.doubleMatrixMultiply(x, x, xCols, xRows, xRows, xCols, xRows, 1, 1, xRows, true, setDimNamesNode), getResultDimsNode); } @Specialization diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Drop.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Drop.java index 0ff9a4a4b01e3264fd6c35ddf61e0ea5c546deae..b968a3e2d1faeb53f2479913e35f9e4f71f50f29 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Drop.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/Drop.java @@ -31,6 +31,7 @@ import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimAttributeNode; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.runtime.builtins.RBuiltin; import com.oracle.truffle.r.runtime.data.RAttributeProfiles; @@ -52,7 +53,8 @@ public abstract class Drop extends RBuiltinNode { @Specialization protected RAbstractVector doDrop(RAbstractVector x, @Cached("create()") GetDimAttributeNode getDimsNode, - @Cached("create()") SetDimAttributeNode setDimsNode) { + @Cached("create()") SetDimAttributeNode setDimsNode, + @Cached("create()") SetDimNamesAttributeNode setDimsNamseNode) { int[] dims = getDimsNode.getDimensions(x); if (nullDimensions.profile(dims == null)) { return x; @@ -73,7 +75,7 @@ public abstract class Drop extends RBuiltinNode { @SuppressWarnings("unused") RAbstractVector r = x.copy(); setDimsNode.setDimensions(x, null); - x.setDimNames(null); + setDimsNamseNode.setDimNames(x, null); x.setNames(null); return x; } @@ -105,9 +107,9 @@ public abstract class Drop extends RBuiltinNode { newDimNames[newDimsIdx++] = oldDimNames.getDataAt(i); } } - result.setDimNames(RDataFactory.createList(newDimNames)); + setDimsNamseNode.setDimNames(result, RDataFactory.createList(newDimNames)); } else { - result.setDimNames(null); + setDimsNamseNode.setDimNames(result, null); } return result; diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IsNA.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IsNA.java index 13dc1bcad62fb47e61c0beebf0f6e1a5cf7a9f2c..59818d65964a964481c7164c9c0ce9c0064e01f4 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IsNA.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/IsNA.java @@ -33,6 +33,7 @@ import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; @@ -76,12 +77,14 @@ public abstract class IsNA extends RBuiltinNode { } @Specialization - protected RLogicalVector isNA(RAbstractIntVector vector, @Cached("create()") GetDimAttributeNode getDimsNode) { + protected RLogicalVector isNA(RAbstractIntVector vector, + @Cached("create()") GetDimAttributeNode getDimsNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { byte[] resultVector = new byte[vector.getLength()]; for (int i = 0; i < vector.getLength(); i++) { resultVector[i] = RRuntime.asLogical(RRuntime.isNA(vector.getDataAt(i))); } - return createResult(resultVector, vector, getDimsNode); + return createResult(resultVector, vector, getDimsNode, setDimNamesNode); } @Specialization @@ -90,22 +93,26 @@ public abstract class IsNA extends RBuiltinNode { } @Specialization - protected RLogicalVector isNA(RAbstractDoubleVector vector, @Cached("create()") GetDimAttributeNode getDimsNode) { + protected RLogicalVector isNA(RAbstractDoubleVector vector, + @Cached("create()") GetDimAttributeNode getDimsNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { byte[] resultVector = new byte[vector.getLength()]; for (int i = 0; i < vector.getLength(); i++) { resultVector[i] = RRuntime.asLogical(RRuntime.isNAorNaN(vector.getDataAt(i))); } - return createResult(resultVector, vector, getDimsNode); + return createResult(resultVector, vector, getDimsNode, setDimNamesNode); } @Specialization - protected RLogicalVector isNA(RComplexVector vector, @Cached("create()") GetDimAttributeNode getDimsNode) { + protected RLogicalVector isNA(RComplexVector vector, + @Cached("create()") GetDimAttributeNode getDimsNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { byte[] resultVector = new byte[vector.getLength()]; for (int i = 0; i < vector.getLength(); i++) { RComplex complex = vector.getDataAt(i); resultVector[i] = RRuntime.asLogical(RRuntime.isNA(complex)); } - return createResult(resultVector, vector, getDimsNode); + return createResult(resultVector, vector, getDimsNode, setDimNamesNode); } @Specialization @@ -114,12 +121,14 @@ public abstract class IsNA extends RBuiltinNode { } @Specialization - protected RLogicalVector isNA(RStringVector vector, @Cached("create()") GetDimAttributeNode getDimsNode) { + protected RLogicalVector isNA(RStringVector vector, + @Cached("create()") GetDimAttributeNode getDimsNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { byte[] resultVector = new byte[vector.getLength()]; for (int i = 0; i < vector.getLength(); i++) { resultVector[i] = RRuntime.asLogical(RRuntime.isNA(vector.getDataAt(i))); } - return createResult(resultVector, vector, getDimsNode); + return createResult(resultVector, vector, getDimsNode, setDimNamesNode); } @Specialization @@ -154,12 +163,14 @@ public abstract class IsNA extends RBuiltinNode { } @Specialization - protected RLogicalVector isNA(RLogicalVector vector, @Cached("create()") GetDimAttributeNode getDimsNode) { + protected RLogicalVector isNA(RLogicalVector vector, + @Cached("create()") GetDimAttributeNode getDimsNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { byte[] resultVector = new byte[vector.getLength()]; for (int i = 0; i < vector.getLength(); i++) { resultVector[i] = (RRuntime.isNA(vector.getDataAt(i)) ? RRuntime.LOGICAL_TRUE : RRuntime.LOGICAL_FALSE); } - return createResult(resultVector, vector, getDimsNode); + return createResult(resultVector, vector, getDimsNode, setDimNamesNode); } @Specialization @@ -173,12 +184,14 @@ public abstract class IsNA extends RBuiltinNode { } @Specialization - protected RLogicalVector isNA(RRawVector vector, @Cached("create()") GetDimAttributeNode getDimsNode) { + protected RLogicalVector isNA(RRawVector vector, + @Cached("create()") GetDimAttributeNode getDimsNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { byte[] resultVector = new byte[vector.getLength()]; for (int i = 0; i < vector.getLength(); i++) { resultVector[i] = RRuntime.LOGICAL_FALSE; } - return createResult(resultVector, vector, getDimsNode); + return createResult(resultVector, vector, getDimsNode, setDimNamesNode); } @Specialization @@ -195,11 +208,11 @@ public abstract class IsNA extends RBuiltinNode { return RRuntime.LOGICAL_FALSE; } - private RLogicalVector createResult(byte[] data, RAbstractVector originalVector, GetDimAttributeNode getDimsNode) { + private RLogicalVector createResult(byte[] data, RAbstractVector originalVector, GetDimAttributeNode getDimsNode, SetDimNamesAttributeNode setDimNamesNode) { RLogicalVector result = RDataFactory.createLogicalVector(data, RDataFactory.COMPLETE_VECTOR, getDimsNode.getDimensions(originalVector), originalVector.getNames(attrProfiles)); RList dimNames = originalVector.getDimNames(attrProfiles); if (nullDimNamesProfile.profile(dimNames != null)) { - result.setDimNames(dimNames); + setDimNamesNode.setDimNames(result, dimNames); } return result; } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java index 5e0c81fa1dfb924c8bf73d2341115114e8247105..7fa39325e78b94f53a5382f965f586124fae1cb2 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/LaFunctions.java @@ -35,6 +35,7 @@ import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.attributes.SetFixedAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimAttributeNode; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.unary.CastDoubleNode; @@ -484,7 +485,8 @@ public class LaFunctions { @Specialization protected RDoubleVector doDetGeReal(RDoubleVector aIn, boolean piv, double tol, - @Cached("create()") GetDimAttributeNode getDimsNode) { + @Cached("create()") GetDimAttributeNode getDimsNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { RDoubleVector a = (RDoubleVector) aIn.copy(); int[] aDims = getDimsNode.getDimensions(aIn); int n = aDims[0]; @@ -523,7 +525,7 @@ public class LaFunctions { for (int i = 0; i < m; i++) { dn2[i] = dn.getDataAt(ipiv[i] - 1); } - a.setDimNames(RDataFactory.createList(dn2)); + setDimNamesNode.setDimNames(a, RDataFactory.createList(dn2)); } } return a; @@ -558,7 +560,8 @@ public class LaFunctions { protected RDoubleVector laSolve(RAbstractVector a, RDoubleVector bin, double tol, @Cached("create()") GetDimAttributeNode getADimsNode, @Cached("create()") GetDimAttributeNode getBinDimsNode, - @Cached("create()") SetDimAttributeNode setBDimsNode) { + @Cached("create()") SetDimAttributeNode setBDimsNode, + @Cached("create()") SetDimNamesAttributeNode setBDimNamesNode) { int[] aDims = getADimsNode.getDimensions(a); int n = aDims[0]; if (n == 0) { @@ -597,7 +600,7 @@ public class LaFunctions { bDnData[1] = binDn.getDataAt(1); } if (bDnData[0] != null || bDnData[1] != null) { - b.setDimNames(RDataFactory.createList(bDnData)); + setBDimNamesNode.setDimNames(b, RDataFactory.createList(bDnData)); } } } else { diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/MatMult.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/MatMult.java index 616b798d932dda8bf7da680e6fd6d137d0d66f3d..85ca79cd914681006c88b73b439dec0a42895154 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/MatMult.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/MatMult.java @@ -36,6 +36,7 @@ import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.api.profiles.LoopConditionProfile; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimAttributeNode; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimNamesAttributeNode; import com.oracle.truffle.r.nodes.binary.BinaryMapArithmeticFunctionNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.runtime.RError; @@ -141,8 +142,8 @@ public abstract class MatMult extends RBuiltinNode { private final BranchProfile incompleteProfile = BranchProfile.create(); @CompilationFinal private boolean seenLargeMatrix; - private RDoubleVector doubleMatrixMultiply(RAbstractDoubleVector a, RAbstractDoubleVector b, int aRows, int aCols, int bRows, int bCols) { - return doubleMatrixMultiply(a, b, aRows, aCols, bRows, bCols, 1, aRows, 1, bRows, false); + private RDoubleVector doubleMatrixMultiply(RAbstractDoubleVector a, RAbstractDoubleVector b, int aRows, int aCols, int bRows, int bCols, SetDimNamesAttributeNode setDimNamesNode) { + return doubleMatrixMultiply(a, b, aRows, aCols, bRows, bCols, 1, aRows, 1, bRows, false, setDimNamesNode); } /** @@ -163,7 +164,7 @@ public abstract class MatMult extends RBuiltinNode { * @return the result vector */ public RDoubleVector doubleMatrixMultiply(RAbstractDoubleVector a, RAbstractDoubleVector b, int aRows, int aCols, int bRows, int bCols, int aRowStride, int aColStride, int bRowStride, - int bColStride, boolean mirrored) { + int bColStride, boolean mirrored, SetDimNamesAttributeNode setDimNamesNode) { if (aCols != bRows) { errorProfile.enter(); throw RError.error(this, RError.Message.NON_CONFORMABLE_ARGS); @@ -228,7 +229,7 @@ public abstract class MatMult extends RBuiltinNode { if (bDimNames != null && bDimNames.getLength() > 1) { newDimsNames[1] = bDimNames.getDataAt(1); } - resultVec.setDimNames(RDataFactory.createList(newDimsNames)); + setDimNamesNode.setDimNames(resultVec, RDataFactory.createList(newDimsNames)); return resultVec; } @@ -273,12 +274,13 @@ public abstract class MatMult extends RBuiltinNode { protected RDoubleVector multiply(RAbstractDoubleVector a, RAbstractDoubleVector b, @Cached("createBinaryProfile()") ConditionProfile aIsMatrix, @Cached("createBinaryProfile()") ConditionProfile bIsMatrix, - @Cached("createBinaryProfile()") ConditionProfile lengthEquals) { + @Cached("createBinaryProfile()") ConditionProfile lengthEquals, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { if (aIsMatrix.profile(a.isMatrix())) { if (bIsMatrix.profile(b.isMatrix())) { int[] aDimensions = getADimsNode.getDimensions(a); int[] bDimensions = getBDimsNode.getDimensions(b); - return doubleMatrixMultiply(a, b, aDimensions[0], aDimensions[1], bDimensions[0], bDimensions[1]); + return doubleMatrixMultiply(a, b, aDimensions[0], aDimensions[1], bDimensions[0], bDimensions[1], setDimNamesNode); } else { int[] aDim = getADimsNode.getDimensions(a); int aRows = aDim[0]; @@ -292,7 +294,7 @@ public abstract class MatMult extends RBuiltinNode { bRows = 1; bCols = b.getLength(); } - return doubleMatrixMultiply(a, b, aRows, aCols, bRows, bCols); + return doubleMatrixMultiply(a, b, aRows, aCols, bRows, bCols, setDimNamesNode); } } else { if (bIsMatrix.profile(b.isMatrix())) { @@ -308,7 +310,7 @@ public abstract class MatMult extends RBuiltinNode { aRows = a.getLength(); aCols = 1; } - return doubleMatrixMultiply(a, b, aRows, aCols, bRows, bCols); + return doubleMatrixMultiply(a, b, aRows, aCols, bRows, bCols, setDimNamesNode); } else { if (a.getLength() != b.getLength()) { errorProfile.enter(); @@ -640,32 +642,36 @@ public abstract class MatMult extends RBuiltinNode { protected RDoubleVector multiply(RAbstractIntVector a, RAbstractDoubleVector b, @Cached("createBinaryProfile()") ConditionProfile aIsMatrix, @Cached("createBinaryProfile()") ConditionProfile bIsMatrix, - @Cached("createBinaryProfile()") ConditionProfile lengthEquals) { - return multiply(RClosures.createIntToDoubleVector(a), b, aIsMatrix, bIsMatrix, lengthEquals); + @Cached("createBinaryProfile()") ConditionProfile lengthEquals, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { + return multiply(RClosures.createIntToDoubleVector(a), b, aIsMatrix, bIsMatrix, lengthEquals, setDimNamesNode); } @Specialization protected RDoubleVector multiply(RAbstractDoubleVector a, RAbstractIntVector b, @Cached("createBinaryProfile()") ConditionProfile aIsMatrix, @Cached("createBinaryProfile()") ConditionProfile bIsMatrix, - @Cached("createBinaryProfile()") ConditionProfile lengthEquals) { - return multiply(a, RClosures.createIntToDoubleVector(b), aIsMatrix, bIsMatrix, lengthEquals); + @Cached("createBinaryProfile()") ConditionProfile lengthEquals, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { + return multiply(a, RClosures.createIntToDoubleVector(b), aIsMatrix, bIsMatrix, lengthEquals, setDimNamesNode); } @Specialization protected RDoubleVector multiply(RAbstractLogicalVector a, RAbstractDoubleVector b, @Cached("createBinaryProfile()") ConditionProfile aIsMatrix, @Cached("createBinaryProfile()") ConditionProfile bIsMatrix, - @Cached("createBinaryProfile()") ConditionProfile lengthEquals) { - return multiply(RClosures.createLogicalToDoubleVector(a), b, aIsMatrix, bIsMatrix, lengthEquals); + @Cached("createBinaryProfile()") ConditionProfile lengthEquals, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { + return multiply(RClosures.createLogicalToDoubleVector(a), b, aIsMatrix, bIsMatrix, lengthEquals, setDimNamesNode); } @Specialization protected RDoubleVector multiply(RAbstractDoubleVector a, RAbstractLogicalVector b, @Cached("createBinaryProfile()") ConditionProfile aIsMatrix, @Cached("createBinaryProfile()") ConditionProfile bIsMatrix, - @Cached("createBinaryProfile()") ConditionProfile lengthEquals) { - return multiply(a, RClosures.createLogicalToDoubleVector(b), aIsMatrix, bIsMatrix, lengthEquals); + @Cached("createBinaryProfile()") ConditionProfile lengthEquals, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { + return multiply(a, RClosures.createLogicalToDoubleVector(b), aIsMatrix, bIsMatrix, lengthEquals, setDimNamesNode); } // errors diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/NChar.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/NChar.java index cb00025437bdadb08565c25e66bc32b2503a204a..16448cf18eef95829a39df2a9dded14bd48c3742 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/NChar.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/NChar.java @@ -34,6 +34,7 @@ import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.api.profiles.LoopConditionProfile; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.CastBuilder; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.runtime.RRuntime; @@ -70,7 +71,8 @@ public abstract class NChar extends RBuiltinNode { @Cached("createCountingProfile()") LoopConditionProfile loopProfile, @Cached("create()") RAttributeProfiles attrProfiles, @Cached("createBinaryProfile()") ConditionProfile nullDimNamesProfile, - @Cached("create()") GetDimAttributeNode getDimNode) { + @Cached("create()") GetDimAttributeNode getDimNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { int len = vector.getLength(); int[] result = new int[len]; loopProfile.profileCounted(len); @@ -85,7 +87,7 @@ public abstract class NChar extends RBuiltinNode { RIntVector resultVector = RDataFactory.createIntVector(result, true, getDimNode.getDimensions(vector), vector.getNames(attrProfiles)); RList dimNames = vector.getDimNames(attrProfiles); if (nullDimNamesProfile.profile(dimNames != null)) { - resultVector.setDimNames(dimNames); + setDimNamesNode.setDimNames(resultVector, dimNames); } return resultVector; } @@ -96,7 +98,8 @@ public abstract class NChar extends RBuiltinNode { @Cached("createCountingProfile()") LoopConditionProfile loopProfile, @Cached("create()") RAttributeProfiles attrProfiles, @Cached("createBinaryProfile()") ConditionProfile nullDimNamesProfile, - @Cached("create()") GetDimAttributeNode getDimNode) { + @Cached("create()") GetDimAttributeNode getDimNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { int len = vector.getLength(); int[] result = new int[len]; loopProfile.profileCounted(len); @@ -106,7 +109,7 @@ public abstract class NChar extends RBuiltinNode { RIntVector resultVector = RDataFactory.createIntVector(result, true, getDimNode.getDimensions(vector), vector.getNames(attrProfiles)); RList dimNames = vector.getDimNames(attrProfiles); if (nullDimNamesProfile.profile(dimNames != null)) { - resultVector.setDimNames(dimNames); + setDimNamesNode.setDimNames(resultVector, dimNames); } return resultVector; } diff --git a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateDimNames.java b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateDimNames.java index 160eb65328e509c85b9380e7359ee9344d4be57f..6983ebd12510e578c6d7934d4c476b231cbbf331 100644 --- a/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateDimNames.java +++ b/com.oracle.truffle.r.nodes.builtin/src/com/oracle/truffle/r/nodes/builtin/base/UpdateDimNames.java @@ -31,8 +31,7 @@ import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.attributes.RemoveFixedAttributeNode; -import com.oracle.truffle.r.nodes.attributes.SetFixedAttributeNode; -import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimNamesAttributeNode; import com.oracle.truffle.r.nodes.builtin.RBuiltinNode; import com.oracle.truffle.r.nodes.unary.CastStringNode; import com.oracle.truffle.r.nodes.unary.CastStringNodeGen; @@ -41,11 +40,8 @@ import com.oracle.truffle.r.nodes.unary.CastToVectorNodeGen; import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.builtins.RBuiltin; -import com.oracle.truffle.r.runtime.data.RAttributesLayout; import com.oracle.truffle.r.runtime.data.RList; import com.oracle.truffle.r.runtime.data.RNull; -import com.oracle.truffle.r.runtime.data.RStringVector; -import com.oracle.truffle.r.runtime.data.RVector; import com.oracle.truffle.r.runtime.data.model.RAbstractContainer; import com.oracle.truffle.r.runtime.data.model.RAbstractVector; @@ -109,10 +105,9 @@ public abstract class UpdateDimNames extends RBuiltinNode { @Specialization(guards = "list.getLength() > 0") protected RAbstractContainer updateDimnames(RAbstractContainer container, RList list, // - @Cached("createDimNames()") SetFixedAttributeNode attrSetter, - @Cached("create()") GetDimAttributeNode getDimNode) { + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { RAbstractContainer result = (RAbstractContainer) container.getNonShared(); - setDimNames(result, convertToListOfStrings(list), attrSetter, getDimNode); + setDimNamesNode.setDimNames(result, convertToListOfStrings(list)); return result; } @@ -122,57 +117,4 @@ public abstract class UpdateDimNames extends RBuiltinNode { throw RError.error(this, RError.Message.DIMNAMES_LIST); } - private void setDimNames(RAbstractContainer container, RList newDimNames, SetFixedAttributeNode attrSetter, GetDimAttributeNode getDimNode) { - assert newDimNames != null; - if (isRVectorProfile.profile(container instanceof RVector)) { - RVector<?> vector = (RVector<?>) container; - int[] dimensions = getDimNode.getDimensions(vector); - if (dimensions == null) { - CompilerDirectives.transferToInterpreter(); - throw RError.error(this, RError.Message.DIMNAMES_NONARRAY); - } - int newDimNamesLength = newDimNames.getLength(); - if (newDimNamesLength > dimensions.length) { - CompilerDirectives.transferToInterpreter(); - throw RError.error(this, RError.Message.DIMNAMES_DONT_MATCH_DIMS, newDimNamesLength, dimensions.length); - } - for (int i = 0; i < newDimNamesLength; i++) { - Object dimObject = newDimNames.getDataAt(i); - if (dimObject != RNull.instance) { - if (dimObject instanceof String) { - if (dimensions[i] != 1) { - CompilerDirectives.transferToInterpreter(); - throw RError.error(this, RError.Message.DIMNAMES_DONT_MATCH_EXTENT, i + 1); - } - } else { - RStringVector dimVector = (RStringVector) dimObject; - if (dimVector == null || dimVector.getLength() == 0) { - newDimNames.updateDataAt(i, RNull.instance, null); - } else if (dimVector.getLength() != dimensions[i]) { - CompilerDirectives.transferToInterpreter(); - throw RError.error(this, RError.Message.DIMNAMES_DONT_MATCH_EXTENT, i + 1); - } - } - } - } - - RList resDimNames = newDimNames; - if (newDimNamesLength < dimensions.length) { - // resize the array and fill the missing entries with NULL-s - resDimNames = (RList) resDimNames.copyResized(dimensions.length, true); - resDimNames.setAttributes(newDimNames); - for (int i = newDimNamesLength; i < dimensions.length; i++) { - resDimNames.updateDataAt(i, RNull.instance, null); - } - } - if (vector.getAttributes() == null) { - vector.initAttributes(RAttributesLayout.createDimNames(resDimNames)); - } else { - attrSetter.execute(vector.getAttributes(), resDimNames); - } - resDimNames.elementNamePrefix = RRuntime.DIMNAMES_LIST_ELEMENT_NAME_PREFIX; - } else { - container.setDimNames(newDimNames); - } - } } diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/vector/CachedExtractVectorNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/vector/CachedExtractVectorNode.java index 805339e2673281bc6819bc4d3cc10b1d08f18946..a9b9c15b75f4df5c27f0cd00bf021c100f800b10 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/vector/CachedExtractVectorNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/access/vector/CachedExtractVectorNode.java @@ -34,6 +34,7 @@ import com.oracle.truffle.r.nodes.access.vector.CachedExtractVectorNodeFactory.S import com.oracle.truffle.r.nodes.access.vector.PositionsCheckNode.PositionProfile; import com.oracle.truffle.r.nodes.attributes.GetFixedAttributeNode; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimAttributeNode; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimNamesAttributeNode; import com.oracle.truffle.r.nodes.profile.AlwaysOnBranchProfile; import com.oracle.truffle.r.nodes.profile.VectorLengthProfile; import com.oracle.truffle.r.runtime.RError; @@ -281,6 +282,7 @@ final class CachedExtractVectorNode extends CachedVectorNode { private final ConditionProfile foundNamesProfile = ConditionProfile.createBinaryProfile(); @Child private SetDimAttributeNode setDimNode; + @Child private SetDimNamesAttributeNode setDimNamesNode; @ExplodeLoop private void applyDimensions(RAbstractContainer originalTarget, RVector<?> extractedTarget, int extractedTargetLength, PositionProfile[] positionProfile, Object[] positions) { @@ -336,7 +338,12 @@ final class CachedExtractVectorNode extends CachedVectorNode { setDimNode.setDimensions(extractedTarget, newDimensions); if (newDimNames != null) { - extractedTarget.setDimNames(RDataFactory.createList(newDimNames, newDimNamesNames == null ? null : RDataFactory.createStringVector(newDimNamesNames, originalDimNames.isComplete()))); + if (setDimNamesNode == null) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + setDimNamesNode = insert(SetDimNamesAttributeNode.create()); + } + setDimNamesNode.setDimNames(extractedTarget, + RDataFactory.createList(newDimNames, newDimNamesNames == null ? null : RDataFactory.createStringVector(newDimNamesNames, originalDimNames.isComplete()))); } } else if (newDimNames != null && originalDimNamesPRofile.profile(originalDimNames.getLength() > 0)) { RAbstractStringVector foundNames = translateDimNamesToNames(positionProfile, originalDimNames, extractedTargetLength, positions); diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/CopyAttributesNode.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/CopyAttributesNode.java index 62fb60138bb7ae74e342c352b7970c34f5b786c2..ff585c63951527c39f63cf2a0b7b47c33c4a8017 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/CopyAttributesNode.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/CopyAttributesNode.java @@ -29,6 +29,7 @@ import com.oracle.truffle.api.object.DynamicObject; import com.oracle.truffle.api.profiles.BranchProfile; import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.GetDimAttributeNode; +import com.oracle.truffle.r.nodes.attributes.SpecialAttributesFunctions.SetDimNamesAttributeNode; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.data.RAttributeProfiles; import com.oracle.truffle.r.runtime.data.RDataFactory; @@ -129,7 +130,8 @@ public abstract class CopyAttributesNode extends RBaseNode { @Cached("createBinaryProfile()") ConditionProfile hasNamesRight, @Cached("createBinaryProfile()") ConditionProfile hasDimNames, @Cached("create()") GetDimAttributeNode getLeftDimsNode, - @Cached("create()") GetDimAttributeNode getRightDimsNode) { + @Cached("create()") GetDimAttributeNode getRightDimsNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { if (LOG) { log("copyAttributes: =="); countEquals++; @@ -191,7 +193,7 @@ public abstract class CopyAttributesNode extends RBaseNode { if (result != right) { newDimNames = right.getDimNames(attrRightProfiles); if (hasDimNames.profile(newDimNames != null)) { - result.setDimNames(newDimNames); + setDimNamesNode.setDimNames(result, newDimNames); } } } @@ -211,7 +213,8 @@ public abstract class CopyAttributesNode extends RBaseNode { @Cached("createBinaryProfile()") ConditionProfile hasNames, // @Cached("createBinaryProfile()") ConditionProfile hasDimNames, @Cached("create()") GetDimAttributeNode getLeftDimsNode, - @Cached("create()") GetDimAttributeNode getRightDimsNode) { + @Cached("create()") GetDimAttributeNode getRightDimsNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { if (LOG) { log("copyAttributes: <"); countSmaller++; @@ -249,7 +252,7 @@ public abstract class CopyAttributesNode extends RBaseNode { if (rightNotResult) { RList newDimNames = right.getDimNames(attrRightProfiles); if (hasDimNames.profile(newDimNames != null)) { - result.setDimNames(newDimNames); + setDimNamesNode.setDimNames(result, newDimNames); } } return result; @@ -267,7 +270,8 @@ public abstract class CopyAttributesNode extends RBaseNode { @Cached("createBinaryProfile()") ConditionProfile hasNames, // @Cached("createBinaryProfile()") ConditionProfile hasDimNames, @Cached("create()") GetDimAttributeNode getLeftDimsNode, - @Cached("create()") GetDimAttributeNode getRightDimsNode) { + @Cached("create()") GetDimAttributeNode getRightDimsNode, + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { if (LOG) { log("copyAttributes: >"); countLarger++; @@ -299,7 +303,7 @@ public abstract class CopyAttributesNode extends RBaseNode { if (left != result) { RList newDimNames = left.getDimNames(attrLeftProfiles); if (hasDimNames.profile(newDimNames != null)) { - result.setDimNames(newDimNames); + setDimNamesNode.setDimNames(result, newDimNames); } } return result; diff --git a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/SpecialAttributesFunctions.java b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/SpecialAttributesFunctions.java index ea270ac4e7452bfedc3b214b84b19cf8b9058ec0..9dc4681c3120d047e58004db1bc509fceb022c31 100644 --- a/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/SpecialAttributesFunctions.java +++ b/com.oracle.truffle.r.nodes/src/com/oracle/truffle/r/nodes/attributes/SpecialAttributesFunctions.java @@ -35,6 +35,7 @@ import com.oracle.truffle.r.runtime.RError; import com.oracle.truffle.r.runtime.RInternalError; import com.oracle.truffle.r.runtime.RRuntime; import com.oracle.truffle.r.runtime.data.RAttributable; +import com.oracle.truffle.r.runtime.data.RAttributesLayout; import com.oracle.truffle.r.runtime.data.RDataFactory; import com.oracle.truffle.r.runtime.data.RIntVector; import com.oracle.truffle.r.runtime.data.RInteger; @@ -361,14 +362,9 @@ public final class SpecialAttributesFunctions { @Specialization(insertBefore = "setAttrInAttributable") protected void resetDims(RAbstractContainer x, @SuppressWarnings("unused") RNull rnull, @Cached("create()") RemoveDimAttributeNode removeDimAttrNode, - @Cached("createBinaryProfile()") ConditionProfile vectorClassProfile, - @Cached("createClassProfile()") ValueProfile xTypeProfile) { + @Cached("create()") SetDimNamesAttributeNode setDimNamesNode) { removeDimAttrNode.execute(x); - if (vectorClassProfile.profile(x instanceof RVector)) { - ((RVector<?>) x).setDimNames(null); - } else { - xTypeProfile.profile(x).setDimNames(null); - } + setDimNamesNode.setDimNames(x, null); } @Specialization(insertBefore = "setAttrInAttributable") @@ -530,55 +526,73 @@ public final class SpecialAttributesFunctions { @Specialization(insertBefore = "setAttrInAttributable") protected void resetDimNames(RAbstractContainer x, @SuppressWarnings("unused") RNull rnull, @Cached("create()") RemoveDimNamesAttributeNode removeDimNamesAttrNode) { -// removeDimNamesAttrNode.execute(x); - x.setDimNames(null); + removeDimNamesAttrNode.execute(x); } @Specialization(insertBefore = "setAttrInAttributable") protected void setDimNamesInVector(RVector<?> x, RList newDimNames, @Cached("create()") GetDimAttributeNode getDimNode, + @Cached("create()") BranchProfile nullDimsProfile, + @Cached("create()") BranchProfile dimsLengthProfile, + @Cached("createCountingProfile()") LoopConditionProfile loopProfile, + @Cached("create()") BranchProfile invalidDimProfile, + @Cached("create()") BranchProfile nullDimProfile, + @Cached("create()") BranchProfile resizeDimsProfile, @Cached("create()") BranchProfile attrNullProfile, @Cached("createBinaryProfile()") ConditionProfile attrStorageProfile, @Cached("createClassProfile()") ValueProfile xTypeProfile) { - x.setDimNames(newDimNames); -// int[] dimensions = getDimNode.getDimensions(x); -// if (dimensions == null) { -// throw RError.error(this, RError.Message.DIMNAMES_NONARRAY); -// } -// int newDimNamesLength = newDimNames.getLength(); -// if (newDimNamesLength > dimensions.length) { -// throw RError.error(this, RError.Message.DIMNAMES_DONT_MATCH_DIMS, newDimNamesLength, -// dimensions.length); -// } -// for (int i = 0; i < newDimNamesLength; i++) { -// Object dimObject = newDimNames.getDataAt(i); -// if (dimObject != RNull.instance) { -// if (dimObject instanceof String) { -// if (dimensions[i] != 1) { -// throw RError.error(this, RError.Message.DIMNAMES_DONT_MATCH_EXTENT, i + 1); -// } -// } else { -// RStringVector dimVector = (RStringVector) dimObject; -// if (dimVector == null) { -// newDimNames.updateDataAt(i, RNull.instance, null); -// } else if (dimVector.getLength() != dimensions[i]) { -// throw RError.error(this, RError.Message.DIMNAMES_DONT_MATCH_EXTENT, i + 1); -// } -// } -// } -// } -// -// RList resDimNames = newDimNames; -// if (newDimNamesLength < dimensions.length) { -// // resize the array and fill the missing entries with NULL-s -// resDimNames = (RList) resDimNames.copyResized(dimensions.length, true); -// resDimNames.setAttributes(newDimNames); -// for (int i = newDimNamesLength; i < dimensions.length; i++) { -// resDimNames.updateDataAt(i, RNull.instance, null); -// } -// } -// resDimNames.elementNamePrefix = RRuntime.DIMNAMES_LIST_ELEMENT_NAME_PREFIX; -// super.setAttrInAttributable(x, newDimNames, attrNullProfile, attrStorageProfile, xTypeProfile); + int[] dimensions = getDimNode.getDimensions(x); + if (dimensions == null) { + nullDimsProfile.enter(); + throw RError.error(this, RError.Message.DIMNAMES_NONARRAY); + } + int newDimNamesLength = newDimNames.getLength(); + if (newDimNamesLength > dimensions.length) { + dimsLengthProfile.enter(); + throw RError.error(this, RError.Message.DIMNAMES_DONT_MATCH_DIMS, newDimNamesLength, + dimensions.length); + } + + loopProfile.profileCounted(newDimNamesLength); + for (int i = 0; loopProfile.inject(i < newDimNamesLength); i++) { + Object dimObject = newDimNames.getDataAt(i); + + if ((dimObject instanceof String && dimensions[i] != 1) || + (dimObject instanceof RStringVector && !isValidDimLength((RStringVector) dimObject, dimensions[i]))) { + invalidDimProfile.enter(); + throw RError.error(this, RError.Message.DIMNAMES_DONT_MATCH_EXTENT, i + 1); + } + + if (dimObject == null || (dimObject instanceof RStringVector && ((RStringVector) dimObject).getLength() == 0)) { + nullDimProfile.enter(); + newDimNames.updateDataAt(i, RNull.instance, null); + } + } + + RList resDimNames = newDimNames; + if (newDimNamesLength < dimensions.length) { + resizeDimsProfile.enter(); + // resize the array and fill the missing entries with NULL-s + resDimNames = (RList) resDimNames.copyResized(dimensions.length, true); + resDimNames.setAttributes(newDimNames); + for (int i = newDimNamesLength; i < dimensions.length; i++) { + resDimNames.updateDataAt(i, RNull.instance, null); + } + } + resDimNames.elementNamePrefix = RRuntime.DIMNAMES_LIST_ELEMENT_NAME_PREFIX; + + if (x.getAttributes() == null) { + attrNullProfile.enter(); + x.initAttributes(RAttributesLayout.createDimNames(resDimNames)); + return; + } + + super.setAttrInAttributable(x, resDimNames, attrNullProfile, attrStorageProfile, xTypeProfile); + } + + private static boolean isValidDimLength(RStringVector x, int expectedDim) { + int len = x.getLength(); + return len == 0 || len == expectedDim; } @Specialization(insertBefore = "setAttrInAttributable")